Repository: ByChelsea/VAND-APRIL-GAN Branch: master Commit: f13b8a634e04 Files: 90 Total size: 24.2 MB Directory structure: gitextract_7tq5j4qy/ ├── LICENSE ├── README.md ├── data/ │ ├── mvtec.py │ └── visa.py ├── dataset.py ├── exps/ │ └── pretrained/ │ ├── mvtec_pretrained.pth │ └── visa_pretrained.pth ├── few_shot.py ├── loss.py ├── model.py ├── open_clip/ │ ├── __init__.py │ ├── coca_model.py │ ├── constants.py │ ├── factory.py │ ├── generation_utils.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs/ │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── push_to_hf_hub.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ ├── utils.py │ └── version.py ├── prompt_ensemble.py ├── requirements.txt ├── test.py ├── test_few_shot.sh ├── test_zero_shot.sh ├── train.py └── train.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 Xuhai Chen Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================

[Workshop Link](https://sites.google.com/view/vand-cvpr23/home) | [Challenge Link](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0) | [Report Paper](https://arxiv.org/abs/2305.17382)

---

[Xuhai Chen](https://scholar.google.com.hk/citations?user=LU4etJ0AAAAJ&hl=zh-CN&authuser=1) · [Yue Han](https://scholar.google.com/citations?hl=en&user=08E500gAAAAJ&view_op=list_works&gmla=AHoSzlVzTXnclaPp9h1g8xAZQBrsxdFXvhunMA3AmRm_GSLnZA1956xavx6hmPaCFCysonsXeTQyhB_cokdUFacUc5HBunMPW-uOtLZLTTufiZiHB6hAVgr9l7cJ_UHKeQ) · [Jiangning Zhang](https://zhangzjn.github.io/)

This repository contains the official PyTorch implementation of the [Zero-/Few-shot Anomaly Classification and Segmentation Method](https://arxiv.org/abs/2305.17382) used in the [CVPR 2023 VAND Challenge](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0), which can be viewed as an improved version of [WinCLIP](https://arxiv.org/abs/2303.14814). We achieved the **Winner** award in the Zero-shot Track and an **Honorable Mention** in the Few-shot Track.

*(Figure: model structure)*

**Results on the Challenge official test set**

*(Figure)*

## Installation

- Prepare the experimental environment:

```shell
pip install -r requirements.txt
```

## Dataset Preparation

### MVTec AD

- Download and extract [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) into `data/mvtec`
- Run `python data/mvtec.py` to obtain `data/mvtec/meta.json`

```
data
├── mvtec
    ├── meta.json
    ├── bottle
        ├── train
            ├── good
                ├── 000.png
        ├── test
            ├── good
                ├── 000.png
            ├── anomaly1
                ├── 000.png
        ├── ground_truth
            ├── anomaly1
                ├── 000.png
```

### VisA

- Download and extract [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar) into `data/visa`
- Run `python data/visa.py` to obtain `data/visa/meta.json`

```
data
├── visa
    ├── meta.json
    ├── candle
        ├── Data
            ├── Images
                ├── Anomaly
                    ├── 000.JPG
                ├── Normal
                    ├── 0000.JPG
            ├── Masks
                ├── Anomaly
                    ├── 000.png
```

## Train

Set the parameters in `train.sh`:

- `train_data_path`: the path to the training dataset
- `dataset`: name of the training dataset, one of `mvtec`, `visa`
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images fed into the CLIP model
- `aug_rate`: the probability of stitching images (only for MVTec AD)

Then run the following command:

```shell
sh train.sh
```

## Test

### Pretrained Models

We provide our pre-trained models in `exps/pretrained`: `mvtec_pretrained.pth` is the model trained on the MVTec AD dataset and `visa_pretrained.pth` is the model trained on the VisA dataset.

Set the parameters in `test_zero_shot.sh`:

- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model

Then run the following command to test them in the zero-shot setting:

```shell
sh test_zero_shot.sh
```
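For concreteness, `test_zero_shot.sh` can be expected to wrap a single call to `test.py`. The sketch below is hypothetical: it assumes the parameters described in this README (`data_path`, `dataset`, `checkpoint_path`, `model`, `pretrained`, `features_list`, `image_size`, `mode`, ...) are exposed as command-line flags, and the concrete values (CLIP variant, layer indices, image size) are placeholders rather than the repository's defaults.

```shell
# Hypothetical contents of test_zero_shot.sh -- flag names mirror the parameter
# list in this README; adjust the placeholder values to your setup.
# The checkpoint chosen here is the one trained on the other dataset (VisA);
# pick whichever checkpoint matches your evaluation protocol.
python test.py \
    --mode zero_shot \
    --dataset mvtec \
    --data_path ./data/mvtec \
    --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
    --model ViT-L-14-336 \
    --pretrained openai \
    --features_list 6 12 18 24 \
    --image_size 518
```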
Set the parameters in `test_few_shot.sh`:

- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model
- `k_shot`: the number of reference images

Then run the following command to test them in the few-shot setting:

```shell
sh test_few_shot.sh
```

### Zero-shot Setting

Set the parameters in `test_zero_shot.sh`:

- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images fed into the CLIP model
- `mode`: zero shot or few shot

Then run the following command:

```shell
sh test_zero_shot.sh
```

### Few-shot Setting

Set the parameters in `test_few_shot.sh`:

- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `few_shot_features`: features stored in the memory banks
- `image_size`: the size of the images fed into the CLIP model
- `mode`: zero shot or few shot
- `k_shot`: the number of reference images
- `seed`: the random seed

Then run the following command:

```shell
sh test_few_shot.sh
```

## Citation

If our work is helpful for your research, please consider citing:

```
@article{chen2023zero,
  title={A Zero-/Few-Shot Anomaly Classification and Segmentation Method for CVPR 2023 VAND Workshop Challenge Tracks 1\&2: 1st Place on Zero-shot AD and 4th Place on Few-shot AD},
  author={Chen, Xuhai and Han, Yue and Zhang, Jiangning},
  journal={arXiv preprint arXiv:2305.17382},
  year={2023}
}
```

## Acknowledgements

We thank [WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation](https://arxiv.org/abs/2303.14814) for the assistance it provided to our research.
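As a supplement to the Dataset Preparation section above: the `data/mvtec.py` and `data/visa.py` helpers shown below serialize `meta.json` as a nested `split → class → list of records` dictionary. A minimal sketch for inspecting one record is given here; the field names come from the solver code, while the path, class name, and example values in the comments are illustrative only.

```python
import json

# Load the metadata produced by `python data/mvtec.py`.
with open('data/mvtec/meta.json') as f:
    meta = json.load(f)

# Records are grouped by split ('train'/'test') and class name.
record = meta['test']['bottle'][0]
# Each record carries the fields assembled in MVTecSolver.run(), e.g.:
#   img_path    -> 'bottle/test/good/000.png'
#   mask_path   -> ''   (empty string for normal images)
#   cls_name    -> 'bottle'
#   specie_name -> 'good'
#   anomaly     -> 0    (1 for defective images)
print(record['img_path'], record['anomaly'])
```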
================================================ FILE: data/mvtec.py ================================================ import os import json class MVTecSolver(object): CLSNAMES = [ 'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', ] def __init__(self, root='data/mvtec'): self.root = root self.meta_path = f'{root}/meta.json' def run(self): info = dict(train={}, test={}) for cls_name in self.CLSNAMES: cls_dir = f'{self.root}/{cls_name}' for phase in ['train', 'test']: cls_info = [] species = os.listdir(f'{cls_dir}/{phase}') for specie in species: is_abnormal = True if specie not in ['good'] else False img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None img_names.sort() mask_names.sort() if mask_names is not None else None for idx, img_name in enumerate(img_names): info_img = dict( img_path=f'{cls_name}/{phase}/{specie}/{img_name}', mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', cls_name=cls_name, specie_name=specie, anomaly=1 if is_abnormal else 0, ) cls_info.append(info_img) info[phase][cls_name] = cls_info with open(self.meta_path, 'w') as f: f.write(json.dumps(info, indent=4) + "\n") if __name__ == '__main__': runner = MVTecSolver(root='data/mvtec') runner.run() ================================================ FILE: data/visa.py ================================================ import os import json import pandas as pd class VisASolver(object): CLSNAMES = [ 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 'pcb4', 'pipe_fryum', ] def __init__(self, root='data/visa'): self.root = root self.meta_path = f'{root}/meta.json' self.phases = ['train', 'test'] self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) def run(self): columns = self.csv_data.columns # [object, split, label, image, mask] info = {phase: {} for phase in self.phases} for cls_name in self.CLSNAMES: cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] for phase in self.phases: cls_info = [] cls_data_phase = cls_data[cls_data[columns[1]] == phase] cls_data_phase.index = list(range(len(cls_data_phase))) for idx in range(cls_data_phase.shape[0]): data = cls_data_phase.loc[idx] is_abnormal = True if data[2] == 'anomaly' else False info_img = dict( img_path=data[3], mask_path=data[4] if is_abnormal else '', cls_name=cls_name, specie_name='', anomaly=1 if is_abnormal else 0, ) cls_info.append(info_img) info[phase][cls_name] = cls_info with open(self.meta_path, 'w') as f: f.write(json.dumps(info, indent=4) + "\n") if __name__ == '__main__': runner = VisASolver(root='data/visa') runner.run() ================================================ FILE: dataset.py ================================================ import torch.utils.data as data import json import random from PIL import Image import numpy as np import torch import os class VisaDataset(data.Dataset): def __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None): self.root = root self.transform = transform self.target_transform = target_transform self.data_all = [] meta_info = json.load(open(f'{self.root}/meta.json', 'r')) name = self.root.split('/')[-1] meta_info = meta_info[mode] if mode == 'train': self.cls_names = [obj_name] save_dir = os.path.join(save_dir, 'k_shot.txt') else: self.cls_names = list(meta_info.keys()) for 
cls_name in self.cls_names: if mode == 'train': data_tmp = meta_info[cls_name] indices = torch.randint(0, len(data_tmp), (k_shot,)) for i in range(len(indices)): self.data_all.append(data_tmp[indices[i]]) with open(save_dir, "a") as f: f.write(data_tmp[indices[i]]['img_path'] + '\n') else: self.data_all.extend(meta_info[cls_name]) self.length = len(self.data_all) def __len__(self): return self.length def get_cls_names(self): return self.cls_names def __getitem__(self, index): data = self.data_all[index] img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \ data['specie_name'], data['anomaly'] img = Image.open(os.path.join(self.root, img_path)) if anomaly == 0: img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') else: img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0 img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') img = self.transform(img) if self.transform is not None else img img_mask = self.target_transform( img_mask) if self.target_transform is not None and img_mask is not None else img_mask img_mask = [] if img_mask is None else img_mask return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly, 'img_path': os.path.join(self.root, img_path)} class MVTecDataset(data.Dataset): def __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None, obj_name=None): self.root = root self.transform = transform self.target_transform = target_transform self.aug_rate = aug_rate self.data_all = [] meta_info = json.load(open(f'{self.root}/meta.json', 'r')) name = self.root.split('/')[-1] meta_info = meta_info[mode] if mode == 'train': self.cls_names = [obj_name] save_dir = os.path.join(save_dir, 'k_shot.txt') else: self.cls_names = list(meta_info.keys()) for cls_name in self.cls_names: if mode == 'train': data_tmp = meta_info[cls_name] indices = torch.randint(0, len(data_tmp), (k_shot,)) for i in range(len(indices)): self.data_all.append(data_tmp[indices[i]]) with open(save_dir, "a") as f: f.write(data_tmp[indices[i]]['img_path'] + '\n') else: self.data_all.extend(meta_info[cls_name]) self.length = len(self.data_all) def __len__(self): return self.length def get_cls_names(self): return self.cls_names def combine_img(self, cls_name): img_paths = os.path.join(self.root, cls_name, 'test') img_ls = [] mask_ls = [] for i in range(4): defect = os.listdir(img_paths) random_defect = random.choice(defect) files = os.listdir(os.path.join(img_paths, random_defect)) random_file = random.choice(files) img_path = os.path.join(img_paths, random_defect, random_file) mask_path = os.path.join(self.root, cls_name, 'ground_truth', random_defect, random_file[:3] + '_mask.png') img = Image.open(img_path) img_ls.append(img) if random_defect == 'good': img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') else: img_mask = np.array(Image.open(mask_path).convert('L')) > 0 img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') mask_ls.append(img_mask) # image image_width, image_height = img_ls[0].size result_image = Image.new("RGB", (2 * image_width, 2 * image_height)) for i, img in enumerate(img_ls): row = i // 2 col = i % 2 x = col * image_width y = row * image_height result_image.paste(img, (x, y)) # mask result_mask = Image.new("L", (2 * image_width, 2 * image_height)) for i, img in enumerate(mask_ls): row = i // 2 col = i % 2 x = col * image_width y = row * image_height result_mask.paste(img, 
(x, y)) return result_image, result_mask def __getitem__(self, index): data = self.data_all[index] img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \ data['specie_name'], data['anomaly'] random_number = random.random() if random_number < self.aug_rate: img, img_mask = self.combine_img(cls_name) else: img = Image.open(os.path.join(self.root, img_path)) if anomaly == 0: img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') else: img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0 img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') # transforms img = self.transform(img) if self.transform is not None else img img_mask = self.target_transform( img_mask) if self.target_transform is not None and img_mask is not None else img_mask img_mask = [] if img_mask is None else img_mask return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly, 'img_path': os.path.join(self.root, img_path)} ================================================ FILE: exps/pretrained/mvtec_pretrained.pth ================================================ [File too large to display: 12.0 MB] ================================================ FILE: exps/pretrained/visa_pretrained.pth ================================================ [File too large to display: 12.0 MB] ================================================ FILE: few_shot.py ================================================ import torch from dataset import VisaDataset, MVTecDataset def memory(model_name, model, obj_list, dataset_dir, save_path, preprocess, transform, k_shot, few_shot_features, dataset_name, device): mem_features = {} for obj in obj_list: if dataset_name == 'mvtec': data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform, aug_rate=-1, mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj) else: data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj) dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False) features = [] for items in dataloader: image = items['img'].to(device) with torch.no_grad(): image_features, patch_tokens = model.encode_image(image, few_shot_features) if 'ViT' in model_name: patch_tokens = [p[0, 1:, :] for p in patch_tokens] else: patch_tokens = [p[0].view(p.shape[1], -1).permute(1, 0).contiguous() for p in patch_tokens] features.append(patch_tokens) mem_features[obj] = [torch.cat( [features[j][i] for j in range(len(features))], dim=0) for i in range(len(features[0]))] return mem_features ================================================ FILE: loss.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from math import exp class FocalLoss(nn.Module): """ copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 'Focal Loss for Dense Object Detection. 
(https://arxiv.org/abs/1708.02002)' Focal_Loss= -1*alpha*(1-pt)*log(pt) :param alpha: (tensor) 3D or 4D the scalar factor for this criterion :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more focus on hard misclassified example :param smooth: (float,double) smooth value when cross entropy :param balance_index: (int) balance class index, should be specific when alpha is float :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. """ def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): super(FocalLoss, self).__init__() self.apply_nonlin = apply_nonlin self.alpha = alpha self.gamma = gamma self.balance_index = balance_index self.smooth = smooth self.size_average = size_average if self.smooth is not None: if self.smooth < 0 or self.smooth > 1.0: raise ValueError('smooth value should be in [0,1]') def forward(self, logit, target): if self.apply_nonlin is not None: logit = self.apply_nonlin(logit) num_class = logit.shape[1] if logit.dim() > 2: # N,C,d1,d2 -> N,C,m (m=d1*d2*...) logit = logit.view(logit.size(0), logit.size(1), -1) logit = logit.permute(0, 2, 1).contiguous() logit = logit.view(-1, logit.size(-1)) target = torch.squeeze(target, 1) target = target.view(-1, 1) alpha = self.alpha if alpha is None: alpha = torch.ones(num_class, 1) elif isinstance(alpha, (list, np.ndarray)): assert len(alpha) == num_class alpha = torch.FloatTensor(alpha).view(num_class, 1) alpha = alpha / alpha.sum() elif isinstance(alpha, float): alpha = torch.ones(num_class, 1) alpha = alpha * (1 - self.alpha) alpha[self.balance_index] = self.alpha else: raise TypeError('Not support alpha type') if alpha.device != logit.device: alpha = alpha.to(logit.device) idx = target.cpu().long() one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() one_hot_key = one_hot_key.scatter_(1, idx, 1) if one_hot_key.device != logit.device: one_hot_key = one_hot_key.to(logit.device) if self.smooth: one_hot_key = torch.clamp( one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) pt = (one_hot_key * logit).sum(1) + self.smooth logpt = pt.log() gamma = self.gamma alpha = alpha[idx] alpha = torch.squeeze(alpha) loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt if self.size_average: loss = loss.mean() return loss class BinaryDiceLoss(nn.Module): def __init__(self): super(BinaryDiceLoss, self).__init__() def forward(self, input, targets): # get the batch size N N = targets.size()[0] # smoothing term smooth = 1 # flatten the spatial dimensions input_flat = input.view(N, -1) targets_flat = targets.view(N, -1) # compute the intersection intersection = input_flat * targets_flat N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) # average the per-image Dice loss over the batch loss = 1 - N_dice_eff.sum() / N return loss ================================================ FILE: model.py ================================================ from torch import Tensor, nn import torch from torch.nn import functional as F class LinearLayer(nn.Module): def __init__(self, dim_in, dim_out, k, model): super(LinearLayer, self).__init__() if 'ViT' in model: self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for i in range(k)]) else: self.fc = nn.ModuleList([nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)]) def forward(self, tokens): for i in range(len(tokens)): if len(tokens[i].shape) == 3: tokens[i] = self.fc[i](tokens[i][:, 1:, :]) else: B, C, H, W = tokens[i].shape tokens[i] =
self.fc[i](tokens[i].view(B, C, -1).permute(0, 2, 1).contiguous()) return tokens ================================================ FILE: open_clip/__init__.py ================================================ from .coca_model import CoCa from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub from .tokenizer import SimpleTokenizer, tokenize, decode from .transform import image_transform, AugmentationCfg ================================================ FILE: open_clip/coca_model.py ================================================ from typing import Optional import torch from torch import nn from torch.nn import functional as F import numpy as np from dataclasses import dataclass from .transformer import ( LayerNormFp32, LayerNorm, QuickGELU, MultimodalTransformer, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower try: from transformers import ( BeamSearchScorer, LogitsProcessorList, TopPLogitsWarper, TopKLogitsWarper, RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, StoppingCriteriaList ) GENERATION_TYPES = { "top_k": TopKLogitsWarper, "top_p": TopPLogitsWarper, "beam_search": "beam_search" } _has_transformers = True except ImportError as e: GENERATION_TYPES = { "top_k": None, "top_p": None, "beam_search": "beam_search" } _has_transformers = False @dataclass class MultimodalCfg(CLIPTextCfg): mlp_ratio: int = 4 dim_head: int = 64 heads: int = 8 n_queries: int = 256 attn_pooler_heads: int = 8 def _build_text_decoder_tower( embed_dim, multimodal_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) decoder = MultimodalTransformer( context_length=multimodal_cfg.context_length, width=multimodal_cfg.width, heads=multimodal_cfg.heads, layers=multimodal_cfg.layers, ls_init_value=multimodal_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return decoder class CoCa(nn.Module): def __init__( self, embed_dim, multimodal_cfg: MultimodalCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, pad_id: int = 0, ): super().__init__() multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg self.text = _build_text_tower( embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) vocab_size = ( 
text_cfg.vocab_size # for hf models if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else text_cfg.vocab_size ) self.visual = _build_vision_tower( embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) self.text_decoder = _build_text_decoder_tower( vocab_size, multimodal_cfg=multimodal_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.pad_id = pad_id @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) self.text_decoder.set_grad_checkpointing(enable) # def _encode_image(self, images, out_layers, normalize=True): # image_latent, tokens_embs = self.visual(images, out_layers) # image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent # return image_latent, tokens_embs def _encode_image(self, images, out_layers, normalize=True): image_latent = self.visual(images, out_layers) return image_latent def _encode_text(self, text, normalize=True, embed_cls=True): text = text[:, :-1] if embed_cls else text # make space for CLS token text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb # def encode_image(self, images, out_layers, normalize=True): # image_latent, _ = self._encode_image(images, out_layers, normalize=normalize) # return image_latent def encode_image(self, images, out_layers, normalize=True): image_latent = self._encode_image(images, out_layers, normalize=normalize) return image_latent def encode_text(self, text, normalize=True, embed_cls=True): text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) return text_latent def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) if image_latent is None or image_embs is None: image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? labels = text[:, -token_embs.shape[1]:] logits = self.text_decoder(image_embs, token_embs) return { "image_features": image_latent, "text_features": text_latent, "logits": logits, "labels": labels, "logit_scale": self.logit_scale.exp() } def generate( self, image, text=None, seq_len=30, max_seq_len=77, temperature=1., generation_type="beam_search", top_p=0.1, # keep tokens in the 1 - top_p quantile top_k=1, # keeps the top_k most probable tokens pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, repetition_penalty=1.0, fixed_output_length=False # if True output.shape == (batch_size, seq_len) ): # taking many ideas and components from HuggingFace GenerationMixin # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." 
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" with torch.no_grad(): sot_token_id = 49406 if sot_token_id is None else sot_token_id eos_token_id = 49407 if eos_token_id is None else eos_token_id pad_token_id = self.pad_id if pad_token_id is None else pad_token_id logit_processor = LogitsProcessorList( [ MinLengthLogitsProcessor(min_seq_len, eos_token_id), RepetitionPenaltyLogitsProcessor(repetition_penalty), ] ) if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] stopping_criteria = StoppingCriteriaList( stopping_criteria ) device = image.device if generation_type == "beam_search": output = self._generate_beamsearch( image_inputs = image, pad_token_id=pad_token_id, eos_token_id=eos_token_id, sot_token_id=sot_token_id, num_beams=num_beams, num_beam_groups=num_beam_groups, min_seq_len=min_seq_len, stopping_criteria=stopping_criteria, logit_processor=logit_processor, ) if fixed_output_length and output.shape[1] < seq_len: return torch.cat( (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), dim=1 ) return output elif generation_type == "top_p": logit_warper = GENERATION_TYPES[generation_type](top_p) elif generation_type == "top_k": logit_warper = GENERATION_TYPES[generation_type](top_k) else: raise ValueError( f"generation_type has to be one of " f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." ) image_latent, image_embs = self._encode_image(image) if text is None: text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id was_training = self.training num_dims = len(text.shape) if num_dims == 1: text = text[None, :] cur_len = text.shape[1] self.eval() out = text while True: x = out[:, -max_seq_len:] cur_len = x.shape[1] logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id if mask.all(): if not fixed_output_length: break else: logits = logits[~mask, :] filtered_logits = logit_processor(x[~mask, :], logits) filtered_logits = logit_warper(x[~mask, :], filtered_logits) probs = F.softmax(filtered_logits / temperature, dim=-1) if (cur_len + 1 == seq_len): sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id else: sample[~mask, :] = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) cur_len += 1 if stopping_criteria(out, None): break if num_dims == 1: out = out.squeeze(0) self.train(was_training) return out def _generate_beamsearch( self, image_inputs, pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, logit_processor=None, logit_warper=None, ): device = image_inputs.device batch_size = image_inputs.shape[0] image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) image_latent, image_embs = self._encode_image(image_inputs) input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) input_ids = input_ids * sot_token_id beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, device=device, num_beam_groups=num_beam_groups, ) # instantiate logits processors logits_processor = ( LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) if logit_processor is None else logit_processor ) batch_size = 
len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups batch_beam_size, cur_len = input_ids.shape beam_indices = None if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in # the same group don't produce same tokens everytime. beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) while True: # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) outputs = self( model_inputs['images'], model_inputs['text'], embed_cls=False, image_latent=image_latent, image_embs=image_embs ) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of currentg group only next_token_logits = outputs['logits'][batch_group_indices, -1, :] vocab_size = next_token_logits.shape[-1] next_token_scores_processed = logits_processor( group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=process_beam_indices, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) ) input_ids = torch.cat([input_ids, 
current_tokens.unsqueeze(-1)], dim=-1) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, None): break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, ) return sequence_outputs['sequences'] def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): if past: input_ids = input_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = None return { "text": input_ids, "images": image_inputs, "past_key_values": past, "position_ids": position_ids, "attention_mask": attention_mask, } ================================================ FILE: open_clip/constants.py ================================================ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) ================================================ FILE: open_clip/factory.py ================================================ import json import logging import os import pathlib import re import numpy as np from copy import deepcopy from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype from .coca_model import CoCa from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform, AugmentationCfg from .tokenizer import HFTokenizer, tokenize HF_HUB_PREFIX = 'hf-hub:' _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS config_ext = ('.json',) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: config_files.extend(config_path.glob(f'*{ext}')) for cf in config_files: with open(cf, 'r') as f: model_cfg = json.load(f) if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} _rescan_model_configs() # initial populate of model config registry def list_models(): """ enumerate available model architectures based on config files """ return list(_MODEL_CONFIGS.keys()) def add_model_config(path): """ add model config path or file and update registry """ if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) _rescan_model_configs() def get_model_config(model_name): if 
model_name in _MODEL_CONFIGS: return deepcopy(_MODEL_CONFIGS[model_name]) else: return None def get_tokenizer(model_name): if model_name.startswith(HF_HUB_PREFIX): tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) else: config = get_model_config(model_name) tokenizer = HFTokenizer( config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize return tokenizer def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict def load_checkpoint(model, checkpoint_path, strict=True): state_dict = load_state_dict(checkpoint_path) # detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) resize_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys def create_model( model_name: str, img_size: int, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, pretrained_image: bool = False, pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, ): has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) pretrained_cfg = config['preprocess_cfg'] model_cfg = config['model_cfg'] else: model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names checkpoint_path = None pretrained_cfg = {} model_cfg = None if isinstance(device, str): device = torch.device(device) if pretrained and pretrained.lower() == 'openai': logging.info(f'Loading pretrained {model_name} from OpenAI.') model_cfg = model_cfg or get_model_config(model_name) if model_cfg['vision_cfg']['image_size'] != img_size: model_cfg['vision_cfg']['image_size'] = img_size cast_dtype = get_cast_dtype(precision) model_pre = load_openai_model( model_name, precision=precision, device=device, jit=jit, cache_dir=cache_dir, ) state_dict = model_pre.state_dict() # to always output dict even if it is clip if output_dict and hasattr(model_pre, "output_dict"): model_pre.output_dict = True model = CLIP(**model_cfg, cast_dtype=cast_dtype) ### for resnet if not hasattr(model.visual, 'grid_size'): model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1)) resize_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=True) model.to(device=device) if precision in ("fp16", "bf16"): convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) # set image / mean metadata from pretrained_cfg if available, or use default 
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD # to always output dict even if it is clip if output_dict and hasattr(model, "output_dict"): model.output_dict = True if jit: model = torch.jit.script(model) else: model = load_openai_model( model_name, precision=precision, device=device, jit=jit, cache_dir=cache_dir, ) # to always output dict even if it is clip if output_dict and hasattr(model, "output_dict"): model.output_dict = True else: model_cfg = model_cfg or get_model_config(model_name) model_cfg['vision_cfg']['image_size'] = img_size if model_cfg is not None: logging.info(f'Loaded {model_name} model config.') pass else: logging.error(f'Model config for {model_name} not found; available models {list_models()}.') raise RuntimeError(f'Model config for {model_name} not found.') if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models model_cfg["quick_gelu"] = True if force_patch_dropout is not None: # override the default patch dropout value model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout if force_image_size is not None: # override model config's image size model_cfg["vision_cfg"]["image_size"] = force_image_size if pretrained_image: if 'timm_model_name' in model_cfg.get('vision_cfg', {}): # pretrained weight loading for timm models set via vision_cfg model_cfg['vision_cfg']['timm_model_pretrained'] = True else: assert False, 'pretrained image towers currently only supported for timm models' cast_dtype = get_cast_dtype(precision) is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model if custom_text: if is_hf_model: model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf if "coca" in model_name: model = CoCa(**model_cfg, cast_dtype=cast_dtype) else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_loaded = False if pretrained: checkpoint_path = '' pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) elif os.path.exists(pretrained): checkpoint_path = pretrained if checkpoint_path: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') load_checkpoint(model, checkpoint_path) else: error_str = ( f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') logging.warning(error_str) raise RuntimeError(error_str) pretrained_loaded = True elif has_hf_hub_prefix: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') load_checkpoint(model, checkpoint_path) pretrained_loaded = True if require_pretrained and not pretrained_loaded: # callers of create_model_from_pretrained always expect pretrained weights raise RuntimeError( f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') model.to(device=device) if precision in ("fp16", "bf16"): convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) # set image / mean metadata from pretrained_cfg if available, or use default model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD # to always output dict even if it is clip if output_dict and hasattr(model, "output_dict"): model.output_dict = True if jit: model = torch.jit.script(model) return model def create_loss(args): if args.distill: return DistillClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, cache_labels=True, rank=args.rank, world_size=args.world_size, use_horovod=args.horovod, ) elif "coca" in args.model.lower(): return CoCaLoss( caption_loss_weight=args.coca_caption_loss_weight, clip_loss_weight=args.coca_contrastive_loss_weight, local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, cache_labels=True, rank=args.rank, world_size=args.world_size, use_horovod=args.horovod, ) return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, cache_labels=True, rank=args.rank, world_size=args.world_size, use_horovod=args.horovod, ) def create_model_and_transforms( model_name: str, img_size: int, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, pretrained_image: bool = False, pretrained_hf: bool = True, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, ): model = create_model( model_name, img_size, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text, force_patch_dropout=force_patch_dropout, force_image_size=force_image_size, pretrained_image=pretrained_image, pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, ) image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) preprocess_train = image_transform( model.visual.image_size, is_train=True, mean=image_mean, std=image_std, aug_cfg=aug_cfg, ) preprocess_val = image_transform( model.visual.image_size, is_train=False, mean=image_mean, std=image_std, ) return model, preprocess_train, preprocess_val def create_model_from_pretrained( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, force_image_size: Optional[Union[int, 
Tuple[int, int]]] = None, return_transform: bool = True, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] = None, ): model = create_model( model_name, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text, force_image_size=force_image_size, cache_dir=cache_dir, require_pretrained=True, ) if not return_transform: return model image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) preprocess = image_transform( model.visual.image_size, is_train=False, mean=image_mean, std=image_std, ) return model, preprocess ================================================ FILE: open_clip/generation_utils.py ================================================ ================================================ FILE: open_clip/hf_configs.py ================================================ # HF architecture dict: arch_dict = { # https://huggingface.co/docs/transformers/model_doc/roberta#roberta "roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig "xlm-roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 "mt5": { "config_names": { # unlimited seqlen # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 "context_length": "", "vocab_size": "vocab_size", "width": "d_model", "heads": "num_heads", "layers": "num_layers", "layer_attr": "block", "token_embeddings_attr": "embed_tokens" }, "pooler": "mean_pooler", }, } ================================================ FILE: open_clip/hf_model.py ================================================ """ huggingface model adapter Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. """ import re import torch import torch.nn as nn from torch import TensorType try: import transformers from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ BaseModelOutputWithPoolingAndCrossAttentions except ImportError as e: transformers = None class BaseModelOutput: pass class PretrainedConfig: pass from .hf_configs import arch_dict # utils def _camel2snake(s): return re.sub(r'(? 
torch.Tensor: # calculated ground-truth and cache if enabled if self.prev_num_logits != num_logits or device not in self.labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) if self.world_size > 1 and self.local_loss: labels = labels + num_logits * self.rank if self.cache_labels: self.labels[device] = labels self.prev_num_logits = num_logits else: labels = self.labels[device] return labels def get_logits(self, image_features, text_features, logit_scale): if self.world_size > 1: all_image_features, all_text_features = gather_features( image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T logits_per_text = logit_scale * text_features @ all_image_features.T else: logits_per_image = logit_scale * all_image_features @ all_text_features.T logits_per_text = logits_per_image.T else: logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T return logits_per_image, logits_per_text def forward(self, image_features, text_features, logit_scale, output_dict=False): device = image_features.device logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) labels = self.get_ground_truth(device, logits_per_image.shape[0]) total_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 return {"contrastive_loss": total_loss} if output_dict else total_loss class CoCaLoss(ClipLoss): def __init__( self, caption_loss_weight, clip_loss_weight, pad_id=0, # pad_token for open_clip custom tokenizer local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1, use_horovod=False, ): super().__init__( local_loss=local_loss, gather_with_grad=gather_with_grad, cache_labels=cache_labels, rank=rank, world_size=world_size, use_horovod=use_horovod ) self.clip_loss_weight = clip_loss_weight self.caption_loss_weight = caption_loss_weight self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss caption_loss = self.caption_loss( logits.permute(0, 2, 1), labels, ) caption_loss = caption_loss * self.caption_loss_weight if output_dict: return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} return clip_loss, caption_loss class DistillClipLoss(ClipLoss): def dist_loss(self, teacher_logits, student_logits): return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) def forward( self, image_features, text_features, logit_scale, dist_image_features, dist_text_features, dist_logit_scale, output_dict=False, ): logits_per_image, logits_per_text = \ self.get_logits(image_features, text_features, logit_scale) dist_logits_per_image, dist_logits_per_text = \ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) contrastive_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 distill_loss = ( self.dist_loss(dist_logits_per_image, logits_per_image) + self.dist_loss(dist_logits_per_text, logits_per_text) ) / 2 if output_dict: return {"contrastive_loss": contrastive_loss, "distill_loss": 
distill_loss} return contrastive_loss, distill_loss ================================================ FILE: open_clip/model.py ================================================ """ CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ from dataclasses import dataclass import logging import math from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer from .utils import to_2tuple @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer n_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth output_tokens: bool = False @dataclass class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 width: int = 512 heads: int = 8 layers: int = 12 ls_init_value: Optional[float] = None # layer scale initial value hf_model_name: str = None hf_tokenizer_name: str = None hf_model_pretrained: bool = True proj: str = 'mlp' pooler_type: str = 'mean_pooler' embed_cls: bool = False pad_id: int = 0 output_tokens: bool = False def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.timm_model_name: visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, embed_dim=embed_dim, image_size=vision_cfg.image_size, ) act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, width=vision_cfg.width, layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, input_patchnorm=vision_cfg.input_patchnorm, global_average_pool=vision_cfg.global_average_pool, attentional_pool=vision_cfg.attentional_pool, n_queries=vision_cfg.n_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return visual def _build_text_tower( embed_dim: int, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) if text_cfg.hf_model_name: text = HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, pretrained=text_cfg.hf_model_pretrained, output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, output_tokens=text_cfg.output_tokens, pad_id=text_cfg.pad_id, act_layer=act_layer, norm_layer=norm_layer, ) return text class CLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) @torch.jit.ignore def set_grad_checkpointing(self, 
enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable def encode_image(self, image, out_layers, normalize: bool = False): features = self.visual(image, out_layers) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return F.normalize(x, dim=-1) if normalize else x def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) if self.output_dict: return { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } return image_features, text_features, self.logit_scale.exp() class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): self.text.lock(unlocked_layers, freeze_layer_norm) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) if self.output_dict: return { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } return image_features, text_features, self.logit_scale.exp() def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) if l.bias is not None: l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.to(dtype) for name in ["text_projection", "proj"]: if hasattr(l, name): attr = 
getattr(l, name) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) text_cfg = CLIPTextCfg( context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers, ) model = CLIP( embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict flag = 1 old_pos_embed = state_dict.get('visual.positional_embedding', None) if old_pos_embed is None: flag = 0 old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None) if old_pos_embed is None or not 
hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, antialias=antialias, align_corners=False, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img if flag: state_dict['visual.positional_embedding'] = new_pos_embed else: state_dict['visual.attnpool.positional_embedding'] = new_pos_embed ================================================ FILE: open_clip/model_configs/RN101-quickgelu.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 23, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN101.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 23, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN50-quickgelu.json ================================================ { "embed_dim": 1024, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 6, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN50.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 6, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN50x16.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 384, "layers": [ 6, 8, 18, 8 ], "width": 96, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN50x4.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 288, "layers": [ 4, 6, 10, 6 ], "width": 80, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/RN50x64.json ================================================ { "embed_dim": 1024, "vision_cfg": { 
"image_size": 448, "layers": [ 3, 15, 36, 10 ], "width": 128, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-16-plus-240.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 240, "layers": 12, "width": 896, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-16-plus.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 224, "layers": 12, "width": 896, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-32-plus-256.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 256, "layers": 12, "width": 896, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-32-quickgelu.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: open_clip/model_configs/ViT-H-16.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: open_clip/model_configs/ViT-L-14-280.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 280, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-L-14-336.json ================================================ { 
"embed_dim": 768, "vision_cfg": { "image_size": 336, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-L-16-320.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 320, "layers": 24, "width": 1024, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-L-16.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-M-16-alt.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 16, "ls_init_value": 1e-4 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-M-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-M-32-alt.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-M-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-S-16-alt.json ================================================ { "embed_dim": 256, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 256, "heads": 4, "layers": 10 } } ================================================ FILE: open_clip/model_configs/ViT-S-16.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-S-32-alt.json ================================================ { "embed_dim": 256, "vision_cfg": { "image_size": 224, 
"layers": 12, "width": 384, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 256, "heads": 4, "layers": 10 } } ================================================ FILE: open_clip/model_configs/ViT-S-32.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: open_clip/model_configs/ViT-bigG-14.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 224, "layers": 48, "width": 1664, "head_width": 104, "mlp_ratio": 4.9231, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 32 } } ================================================ FILE: open_clip/model_configs/ViT-e-14.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 224, "layers": 56, "width": 1792, "head_width": 112, "mlp_ratio": 8.5715, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 36 } } ================================================ FILE: open_clip/model_configs/ViT-g-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 40, "width": 1408, "head_width": 88, "mlp_ratio": 4.3637, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: open_clip/model_configs/coca_ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32, "attentional_pool": true, "attn_pooler_heads": 8, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "embed_cls": true, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "attn_pooler_heads": 8 }, "custom_text": true } ================================================ FILE: open_clip/model_configs/coca_ViT-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 14, "attentional_pool": true, "attn_pooler_heads": 8, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "embed_cls": true, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "attn_pooler_heads": 12 }, "custom_text": true } ================================================ FILE: open_clip/model_configs/coca_base.json ================================================ { "embed_dim": 512, "multimodal_cfg": { "width": 768, "context_length": 76, "vocab_size": 64000, "mlp_ratio": 4, "layers": 12, "dim_head": 64, "heads": 12, "n_queries": 256, "attn_pooler_heads": 8 }, "vision_cfg": { "image_size": 288, "layers": 12, "width": 768, "patch_size": 18, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 64000, "layers": 12, "heads": 12, "width": 768, "embed_cls": true, "output_tokens": true }, "custom_text": true } 
================================================ FILE: open_clip/model_configs/coca_roberta-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32, "output_tokens": true }, "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", "proj": "linear", "width": 768, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "width": 768, "heads": 8, "layers": 12 }, "custom_text": true } ================================================ FILE: open_clip/model_configs/convnext_base.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_base_w.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_base_w_320.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_large.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_large_d.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "mlp", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 16 } } ================================================ FILE: open_clip/model_configs/convnext_large_d_320.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "mlp", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 16 } } ================================================ FILE: open_clip/model_configs/convnext_small.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "convnext_small", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": 
"linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_tiny.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_tiny", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/convnext_xlarge.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 20 } } ================================================ FILE: open_clip/model_configs/convnext_xxlarge.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xxlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: open_clip/model_configs/convnext_xxlarge_320.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xxlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: open_clip/model_configs/mt5-base-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "google/mt5-base", "hf_tokenizer_name": "google/mt5-base", "proj": "mlp", "pooler_type": "mean_pooler" } } ================================================ FILE: open_clip/model_configs/mt5-xl-ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "hf_model_name": "google/mt5-xl", "hf_tokenizer_name": "google/mt5-xl", "proj": "mlp", "pooler_type": "mean_pooler" } } ================================================ FILE: open_clip/model_configs/roberta-ViT-B-32.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", "proj": "mlp", "pooler_type": "mean_pooler" } } ================================================ FILE: open_clip/model_configs/swin_base_patch4_window7_224.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "swin_base_patch4_window7_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 
224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: open_clip/model_configs/vit_medium_patch16_gap_256.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vit_medium_patch16_gap_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vit_relpos_medium_patch16_cls_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: open_clip/model_configs/xlm-roberta-base-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "xlm-roberta-base", "hf_tokenizer_name": "xlm-roberta-base", "proj": "mlp", "pooler_type": "mean_pooler" } } ================================================ FILE: open_clip/model_configs/xlm-roberta-large-ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "hf_model_name": "xlm-roberta-large", "hf_tokenizer_name": "xlm-roberta-large", "proj": "mlp", "pooler_type": "mean_pooler" } } ================================================ FILE: open_clip/modified_resnet.py ================================================ from collections import OrderedDict import torch from torch import nn from torch.nn import functional as F from open_clip.utils import freeze_batch_norm_2d class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1): super().__init__() # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.act2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.act3 = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential(OrderedDict([ ("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion)) ])) def forward(self, x: torch.Tensor): identity = x out = self.act1(self.bn1(self.conv1(x))) out = self.act2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.act3(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0., out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__(self, layers, output_dim, heads, image_size=224, width=64): super().__init__() self.output_dim = output_dim self.image_size = image_size # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.act2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.act3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) self.init_parameters() def _make_layer(self, planes, blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def init_parameters(self): if self.attnpool is not None: std = self.attnpool.c_proj.in_features ** -0.5 nn.init.normal_(self.attnpool.q_proj.weight, std=std) nn.init.normal_(self.attnpool.k_proj.weight, std=std) nn.init.normal_(self.attnpool.v_proj.weight, std=std) nn.init.normal_(self.attnpool.c_proj.weight, std=std) for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) def lock(self, unlocked_groups=0, freeze_bn_stats=False): assert unlocked_groups == 0, 'partial locking not currently supported for this model' for param in self.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): # FIXME support for non-transformer pass def stem(self, x): x = self.act1(self.bn1(self.conv1(x))) x = self.act2(self.bn2(self.conv2(x))) x = self.act3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x def forward(self, x, out_blocks): x = self.stem(x) x_1 = self.layer1(x) x_2 = self.layer2(x_1) x_3 = self.layer3(x_2) x_4 = self.layer4(x_3) x = self.attnpool(x_4) out_tokens = [] x_blocks = [x_1, x_2, x_3, x_4] for i in out_blocks: out_tokens.append(x_blocks[i - 1]) return x, out_tokens ================================================ FILE: open_clip/openai.py ================================================ """ OpenAI pretrained model functions Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
""" import os import warnings from typing import List, Optional, Union import torch from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url __all__ = ["list_openai_models", "load_openai_model"] def list_openai_models() -> List[str]: """Returns the names of available CLIP models""" return list_pretrained_models_by_tag('openai') def load_openai_model( name: str, precision: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, jit: bool = True, cache_dir: Optional[str] = None, ): """Load a CLIP model Parameters ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict precision: str Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model jit : bool Whether to load the optimized JIT model (default) or more hackable non-JIT model. cache_dir : Optional[str] The directory to cache the downloaded model weights Returns ------- model : torch.nn.Module The CLIP model preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if precision is None: precision = 'fp32' if device == 'cpu' else 'fp16' if get_pretrained_url(name, 'openai'): model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) elif os.path.isfile(name): model_path = name else: raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() state_dict = None except RuntimeError: # loading saved state dict if jit: warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") jit = False state_dict = torch.load(model_path, map_location="cpu") if not jit: # Build a non-jit model from the OpenAI jitted model state dict cast_dtype = get_cast_dtype(precision) try: model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) if precision.startswith('amp') or precision == 'fp32': model.float() elif precision == 'bf16': convert_weights_to_lp(model, dtype=torch.bfloat16) return model # patch the device names device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] def patch_device(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("prim::Constant"): if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_device) patch_device(model.encode_image) patch_device(model.encode_text) # patch dtype to float32 (typically for CPU) if precision == 'fp32': float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() def patch_float(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [1, 2]: # dtype can be the second or third argument to aten::to() if inputs[i].node()["value"] == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float) patch_float(model.encode_image) patch_float(model.encode_text) model.float() # ensure image_size attr available at consistent location for both jit and non-jit model.visual.image_size = model.input_resolution.item() return model ================================================ FILE: open_clip/pretrained.py ================================================ import hashlib import os import urllib import warnings from functools import partial from typing import Dict, Union from tqdm import tqdm from .version import __version__ try: from huggingface_hub import hf_hub_download hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) _has_hf_hub = True except ImportError: hf_hub_download = None _has_hf_hub = False def _pcfg(url='', hf_hub='', mean=None, std=None): return dict( url=url, hf_hub=hf_hub, mean=mean, std=std, ) _RN50 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), cc12m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), ) _RN50_quickgelu = dict( openai=_pcfg( 
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), cc12m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), ) _RN101 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), ) _RN101_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), ) _RN50x4 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), ) _RN50x16 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), ) _RN50x64 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), ) _VITB32 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), laion2b_e16=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') ) _VITB32_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), ) _VITB16 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), # laion400m_32k=_pcfg( # url="", # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # laion400m_64k=_pcfg( # url="", # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), ) _VITL14 = dict( openai=_pcfg( 
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ) _VITL14_336 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), ) _VITH14 = dict( laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), ) _VITg14 = dict( laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), ) _VITbigG14 = dict( laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), ) _robertaViTB32 = dict( laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), ) _xlmRobertaBaseViTB32 = dict( laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), ) _xlmRobertaLargeFrozenViTH14 = dict( frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), ) _convnext_base = dict( laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), ) _convnext_base_w = dict( laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), ) _convnext_base_w_320 = dict( laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), ) _convnext_large_d = dict( laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), ) _convnext_large_d_320 = dict( laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), ) _convnext_xxlarge = dict( laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), ) _coca_VITB32 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') ) _coca_VITL14 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') ) _PRETRAINED = { "RN50": _RN50, "RN50-quickgelu": _RN50_quickgelu, "RN101": _RN101, "RN101-quickgelu": _RN101_quickgelu, "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, "ViT-B-32": _VITB32, "ViT-B-32-quickgelu": _VITB32_quickgelu, "ViT-B-16": _VITB16, "ViT-B-16-plus-240": _VITB16_PLUS_240, "ViT-L-14": _VITL14, "ViT-L-14-336": _VITL14_336, "ViT-H-14": 
_VITH14, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, "roberta-ViT-B-32": _robertaViTB32, "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, "convnext_large_d": _convnext_large_d, "convnext_large_d_320": _convnext_large_d_320, "convnext_xxlarge": _convnext_xxlarge, "coca_ViT-B-32": _coca_VITB32, "coca_ViT-L-14": _coca_VITL14, } def _clean_tag(tag: str): # normalize pretrained tags return tag.lower().replace('-', '_') def list_pretrained(as_str: bool = False): """ returns list of pretrained models Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True """ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] def list_pretrained_models_by_tag(tag: str): """ return all models having the specified pretrain tag """ models = [] tag = _clean_tag(tag) for k in _PRETRAINED.keys(): if tag in _PRETRAINED[k]: models.append(k) return models def list_pretrained_tags_by_model(model: str): """ return all pretrain tags for the specified model architecture """ tags = [] if model in _PRETRAINED: tags.extend(_PRETRAINED[model].keys()) return tags def is_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return False return _clean_tag(tag) in _PRETRAINED[model] def get_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return {} model_pretrained = _PRETRAINED[model] return model_pretrained.get(_clean_tag(tag), {}) def get_pretrained_url(model: str, tag: str): cfg = get_pretrained_cfg(model, _clean_tag(tag)) return cfg.get('url', '') def download_pretrained_from_url( url: str, cache_dir: Union[str, None] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") os.makedirs(cache_dir, exist_ok=True) filename = os.path.basename(url) if 'openaipublic' in url: expected_sha256 = url.split("/")[-2] elif 'mlfoundations' in url: expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] else: expected_sha256 = '' download_target = os.path.join(cache_dir, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): if expected_sha256: if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): return download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") else: return download_target with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") return download_target def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( 'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.') return _has_hf_hub def download_pretrained_from_hf( model_id: str, filename: str = 'open_clip_pytorch_model.bin', revision=None, cache_dir: Union[str, None] = None, ): has_hf_hub(True) cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) return cached_file def download_pretrained( cfg: Dict, force_hf_hub: bool = False, cache_dir: Union[str, None] = None, ): target = '' if not cfg: return target download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') if download_hf_hub and force_hf_hub: # use HF hub even if url exists download_url = '' if download_url: target = download_pretrained_from_url(download_url, cache_dir=cache_dir) elif download_hf_hub: has_hf_hub(True) # we assume the hf_hub entries in pretrained config combine model_id + filename in # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. model_id, filename = os.path.split(download_hf_hub) if filename: target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) else: target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) return target ================================================ FILE: open_clip/push_to_hf_hub.py ================================================ import argparse import json from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple import torch try: from huggingface_hub import ( create_repo, get_hf_file_metadata, hf_hub_download, hf_hub_url, repo_type_and_id_from_hf_id, upload_folder, ) from huggingface_hub.utils import EntryNotFoundError _has_hf_hub = True except ImportError: _has_hf_hub = False from .factory import create_model_from_pretrained, get_model_config, get_tokenizer from .tokenizer import HFTokenizer def save_config_for_hf( model, config_path: str, model_config: Optional[dict] ): preprocess_cfg = { 'mean': model.visual.image_mean, 'std': model.visual.image_std, } hf_config = { 'model_cfg': model_config, 'preprocess_cfg': preprocess_cfg, } with config_path.open('w') as f: json.dump(hf_config, f, indent=2) def save_for_hf( model, tokenizer: HFTokenizer, model_config: dict, save_directory: str, weights_filename='open_clip_pytorch_model.bin', config_filename='open_clip_config.json', ): save_directory = Path(save_directory) save_directory.mkdir(exist_ok=True, parents=True) weights_path = save_directory / weights_filename torch.save(model.state_dict(), weights_path) tokenizer.save_pretrained(save_directory) config_path = save_directory / config_filename save_config_for_hf(model, config_path, model_config=model_config) def push_to_hf_hub( model, tokenizer, model_config: Optional[dict], repo_id: str, commit_message: str = 'Add model', token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, create_pr: bool = False, model_card: Optional[dict] = None, ): if not isinstance(tokenizer, HFTokenizer): # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 tokenizer = HFTokenizer('openai/clip-vit-large-patch14') # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) # Infer complete repo_id from repo_url # Can be different from the input `repo_id` if repo_owner was implicit _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) repo_id = f"{repo_owner}/{repo_name}" # Check if README file already exist in 
repo try: get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) has_readme = True except EntryNotFoundError: has_readme = False # Dump model and push to Hub with TemporaryDirectory() as tmpdir: # Save model weights and config. save_for_hf( model, tokenizer=tokenizer, model_config=model_config, save_directory=tmpdir, ) # Add readme if it does not exist if not has_readme: model_card = model_card or {} model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" readme_text = generate_readme(model_card, model_name) readme_path.write_text(readme_text) # Upload model and return return upload_folder( repo_id=repo_id, folder_path=tmpdir, revision=revision, create_pr=create_pr, commit_message=commit_message, ) def push_pretrained_to_hf_hub( model_name, pretrained: str, repo_id: str, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, commit_message: str = 'Add model', token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, create_pr: bool = False, model_card: Optional[dict] = None, ): model, preprocess_eval = create_model_from_pretrained( model_name, pretrained=pretrained, image_mean=image_mean, image_std=image_std, ) model_config = get_model_config(model_name) assert model_config tokenizer = get_tokenizer(model_name) push_to_hf_hub( model=model, tokenizer=tokenizer, model_config=model_config, repo_id=repo_id, commit_message=commit_message, token=token, revision=revision, private=private, create_pr=create_pr, model_card=model_card, ) def generate_readme(model_card: dict, model_name: str): readme_text = "---\n" readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" readme_text += "library_tag: open_clip\n" readme_text += f"license: {model_card.get('license', 'mit')}\n" if 'details' in model_card and 'Dataset' in model_card['details']: readme_text += 'datasets:\n' readme_text += f"- {model_card['details']['Dataset'].lower()}\n" readme_text += "---\n" readme_text += f"# Model card for {model_name}\n" if 'description' in model_card: readme_text += f"\n{model_card['description']}\n" if 'details' in model_card: readme_text += f"\n## Model Details\n" for k, v in model_card['details'].items(): if isinstance(v, (list, tuple)): readme_text += f"- **{k}:**\n" for vi in v: readme_text += f" - {vi}\n" elif isinstance(v, dict): readme_text += f"- **{k}:**\n" for ki, vi in v.items(): readme_text += f" - {ki}: {vi}\n" else: readme_text += f"- **{k}:** {v}\n" if 'usage' in model_card: readme_text += f"\n## Model Usage\n" readme_text += model_card['usage'] readme_text += '\n' if 'comparison' in model_card: readme_text += f"\n## Model Comparison\n" readme_text += model_card['comparison'] readme_text += '\n' if 'citation' in model_card: readme_text += f"\n## Citation\n" if not isinstance(model_card['citation'], (list, tuple)): citations = [model_card['citation']] else: citations = model_card['citation'] for c in citations: readme_text += f"```bibtex\n{c}\n```\n" return readme_text if __name__ == "__main__": parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") parser.add_argument( "--model", type=str, help="Name of the model to use.", ) parser.add_argument( "--pretrained", type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.", ) parser.add_argument( "--repo-id", type=str, help="Destination HF Hub repo-id ie 'organization/model_id'.", ) parser.add_argument( '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 
help='Override default image mean value of dataset') parser.add_argument( '--image-std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset') args = parser.parse_args() print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') # FIXME add support to pass model_card json / template from file via cmd line push_pretrained_to_hf_hub( args.model, args.pretrained, args.repo_id, image_mean=args.image_mean, # override image mean/std if trained w/ non defaults image_std=args.image_std, ) print(f'{args.model} saved.') ================================================ FILE: open_clip/timm_model.py ================================================ """ timm model adapter Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ import logging from collections import OrderedDict import torch import torch.nn as nn try: import timm from timm.models.layers import Mlp, to_2tuple try: # old timm imports < 0.8.1 from timm.models.layers.attention_pool2d import RotAttentionPool2d from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d except ImportError: # new timm imports >= 0.8.1 from timm.layers import RotAttentionPool2d from timm.layers import AttentionPool2d as AbsAttentionPool2d except ImportError: timm = None from .utils import freeze_batch_norm_2d class TimmModel(nn.Module): """ timm model adapter # FIXME this adapter is a work in progress, may change in ways that break weight compat """ def __init__( self, model_name, embed_dim, image_size=224, pool='avg', proj='linear', proj_bias=False, drop=0., drop_path=None, pretrained=False, ): super().__init__() if timm is None: raise RuntimeError("Please `pip install timm` to use timm models.") self.image_size = to_2tuple(image_size) timm_kwargs = {} if drop_path is not None: timm_kwargs['drop_path_rate'] = drop_path self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) feat_size = self.trunk.default_cfg.get('pool_size', None) feature_ndim = 1 if not feat_size else 2 if pool in ('abs_attn', 'rot_attn'): assert feature_ndim == 2 # if attn pooling used, remove both classifier and default pool self.trunk.reset_classifier(0, global_pool='') else: # reset global pool if pool config set, otherwise leave as network default reset_kwargs = dict(global_pool=pool) if pool else {} self.trunk.reset_classifier(0, **reset_kwargs) prev_chs = self.trunk.num_features head_layers = OrderedDict() if pool == 'abs_attn': head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim elif pool == 'rot_attn': head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) prev_chs = embed_dim else: assert proj, 'projection layer needed if non-attention pooling is used.' 
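            # e.g. with the defaults pool='avg' and proj='linear', the trunk classifier is removed but
            # global average pooling is kept, and the head assembled below is simply
            # Dropout(drop) -> Linear(trunk.num_features, embed_dim).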
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used if proj == 'linear': head_layers['drop'] = nn.Dropout(drop) head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) self.head = nn.Sequential(head_layers) def lock(self, unlocked_groups=0, freeze_bn_stats=False): """ lock modules Args: unlocked_groups (int): leave last n layer groups unlocked (default: 0) """ if not unlocked_groups: # lock full model for param in self.trunk.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self.trunk) else: # NOTE: partial freeze requires latest timm (master) branch and is subject to change try: # FIXME import here until API stable and in an official release from timm.models.helpers import group_parameters, group_modules except ImportError: raise RuntimeError( 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') matcher = self.trunk.group_matcher() gparams = group_parameters(self.trunk, matcher) max_layer_id = max(gparams.keys()) max_layer_id = max_layer_id - unlocked_groups for group_idx in range(max_layer_id + 1): group = gparams[group_idx] for param in group: self.trunk.get_parameter(param).requires_grad = False if freeze_bn_stats: gmodules = group_modules(self.trunk, matcher, reverse=True) gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} freeze_batch_norm_2d(self.trunk, gmodules) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): try: self.trunk.set_grad_checkpointing(enable) except Exception as e: logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') def forward(self, x): x = self.trunk(x) x = self.head(x) return x ================================================ FILE: open_clip/tokenizer.py ================================================ """ CLIP tokenizer Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import gzip import html import os from functools import lru_cache from typing import Union, List import ftfy import regex as re import torch # https://stackoverflow.com/q/62691279 import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @lru_cache() def default_bpe(): return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") @lru_cache() def bytes_to_unicode(): """ Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8+n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). 
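    For example, get_pairs(('l', 'o', 'w')) returns {('l', 'o'), ('o', 'w')}.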
""" pairs = set() prev_char = word[0] for char in word[1:]: pairs.add((prev_char, char)) prev_char = char return pairs def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r'\s+', ' ', text) text = text.strip() return text class SimpleTokenizer(object): def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') merges = merges[1:49152-256-2+1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) vocab = vocab + [v+'' for v in vocab] for merge in merges: vocab.append(''.join(merge)) if not special_tokens: special_tokens = ['', ''] else: special_tokens = ['', ''] + special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {t:t for t in special_tokens} special = "|".join(special_tokens) self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) self.vocab_size = len(self.encoder) self.all_special_ids = [self.encoder[t] for t in special_tokens] def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token[:-1]) + ( token[-1] + '',) pairs = get_pairs(word) if not pairs: return token+'' while True: bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except: new_word.extend(word[i:]) break if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = ' '.join(word) self.cache[token] = word return word def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text _tokenizer = SimpleTokenizer() def decode(output_ids: torch.Tensor): output_ids = output_ids.cpu().numpy() return _tokenizer.decode(output_ids) def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) Parameters ---------- texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ if isinstance(texts, str): texts = [texts] sot_token = _tokenizer.encoder[""] eot_token = _tokenizer.encoder[""] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, 
dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = eot_token result[i, :len(tokens)] = torch.tensor(tokens) return result class HFTokenizer: """HuggingFace tokenizer wrapper""" def __init__(self, tokenizer_name: str): from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) def save_pretrained(self, dest): self.tokenizer.save_pretrained(dest) def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] texts = [whitespace_clean(basic_clean(text)) for text in texts] input_ids = self.tokenizer( texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True, ).input_ids return input_ids ================================================ FILE: open_clip/transform.py ================================================ import warnings from dataclasses import dataclass, asdict from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn import torchvision.transforms.functional as F from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD @dataclass class AugmentationCfg: scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None interpolation: Optional[str] = None re_prob: Optional[float] = None re_count: Optional[int] = None use_timm: bool = False class ResizeMaxSize(nn.Module): def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): super().__init__() if not isinstance(max_size, int): raise TypeError(f"Size should be int. 
Got {type(max_size)}") self.max_size = max_size self.interpolation = interpolation self.fn = min if fn == 'min' else min self.fill = fill def forward(self, img): if isinstance(img, torch.Tensor): height, width = img.shape[:2] else: width, height = img.size scale = self.max_size / float(max(height, width)) if scale != 1.0: new_size = tuple(round(dim * scale) for dim in (height, width)) img = F.resize(img, new_size, self.interpolation) pad_h = self.max_size - new_size[0] pad_w = self.max_size - new_size[1] img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) return img def _convert_to_rgb(image): return image.convert('RGB') def image_transform( image_size: int, is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, resize_longest_max: bool = False, fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 std = std or OPENAI_DATASET_STD if not isinstance(std, (list, tuple)): std = (std,) * 3 if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: # for square size, pass size as int so that Resize() uses aspect preserving shortest edge image_size = image_size[0] if isinstance(aug_cfg, dict): aug_cfg = AugmentationCfg(**aug_cfg) else: aug_cfg = aug_cfg or AugmentationCfg() normalize = Normalize(mean=mean, std=std) if is_train: aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} use_timm = aug_cfg_dict.pop('use_timm', False) if use_timm: from timm.data import create_transform # timm can still be optional if isinstance(image_size, (tuple, list)): assert len(image_size) >= 2 input_size = (3,) + image_size[-2:] else: input_size = (3, image_size, image_size) # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time aug_cfg_dict.setdefault('interpolation', 'random') aug_cfg_dict.setdefault('color_jitter', None) # disable by default train_transform = create_transform( input_size=input_size, is_training=True, hflip=0., mean=mean, std=std, re_mode='pixel', **aug_cfg_dict, ) else: train_transform = Compose([ RandomResizedCrop( image_size, scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ToTensor(), normalize, ]) if aug_cfg_dict: warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') return train_transform else: if resize_longest_max: transforms = [ ResizeMaxSize(image_size, fill=fill_color) ] else: transforms = [ Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), CenterCrop((image_size, image_size)), ] transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) return Compose(transforms) ================================================ FILE: open_clip/transformer.py ================================================ from collections import OrderedDict import math from typing import Callable, Optional, Sequence, Tuple import torch from torch import nn from torch.nn import functional as F from torch.utils.checkpoint import checkpoint from .utils import to_2tuple import numpy as np class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class 
LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm (with cast back to input dtype).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class QuickGELU(nn.Module): # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 """ def __init__(self, prob, exclude_first_token=True): super().__init__() assert 0 <= prob < 1. self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token def forward(self, x): if not self.training or self.prob == 0.: return x if self.exclude_first_token: cls_tokens, x = x[:, :1], x[:, 1:] else: cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) batch = x.size()[0] num_tokens = x.size()[1] batch_indices = torch.arange(batch) batch_indices = batch_indices[..., None] keep_prob = 1 - self.prob num_patches_keep = max(1, int(num_tokens * keep_prob)) rand = torch.randn(batch, num_tokens) patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices x = x[batch_indices, patch_indices_keep] if self.exclude_first_token: x = torch.cat((cls_tokens, x), dim=1) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1. / 0.01), attn_drop=0., proj_drop=0. ): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) if qkv_bias: self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) else: self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) else: self.logit_scale = None self.attn_drop = nn.Dropout(attn_drop) if self.scale_heads: self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) else: self.head_scale = None self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) def forward(self, x, attn_mask: Optional[torch.Tensor] = None): L, N, C = x.shape q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(N, self.num_heads, L, L) * logit_scale attn = attn.view(-1, L, L) else: q = q * self.scale attn = torch.bmm(q, k.transpose(-1, -2)) if attn_mask is not None: if attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = 
new_attn_mask attn += attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(N, self.num_heads, L, C) * self.head_scale x = x.view(-1, L, C) x = x.transpose(0, 1).reshape(L, N, C) x = self.out_proj(x) x = self.out_drop(x) return x class AttentionalPooler(nn.Module): def __init__( self, d_model: int, context_dim: int, n_head: int = 8, n_queries: int = 256, norm_layer: Callable = LayerNorm ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) self.ln_q = norm_layer(d_model) self.ln_k = norm_layer(context_dim) def forward(self, x: torch.Tensor): x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND N = x.shape[1] q = self.ln_q(self.query) out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] return out.permute(1, 0, 2) # LND -> NLD def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) class ResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, is_cross_attention: bool = False, idx: int = 12, ): super().__init__() self.idx = idx self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() if is_cross_attention: self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def attention( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask ) def forward( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) x = q_x + self.ls_1(tmp) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x, attn class CustomResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, scale_cosine_attn: bool = False, scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, ) self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ('ln', norm_layer(mlp_width) if scale_fc else 
nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x class Transformer(nn.Module): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, ): super().__init__() self.width = width self.layers = layers self.grad_checkpointing = False self.resblocks = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, idx=idx) for idx in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9], attn_mask: Optional[torch.Tensor] = None): idx = 0 out_attn = [] # out_tokens = x out_tokens = [] for r in self.resblocks: idx += 1 if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: if idx == 12: x, attn = r(x, attn_mask=attn_mask) out_attn.append(attn) else: x, attn_tmp = r(x, attn_mask=attn_mask) if idx in out_layers: out_tokens.append(x) # out_tokens = x return x, out_attn, out_tokens class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, ls_init_value: float = None, global_average_pool: bool = False, attentional_pool: bool = False, n_queries: int = 256, attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., input_patchnorm: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False ): super().__init__() self.output_tokens = output_tokens image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) self.output_dim = output_dim # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 self.input_patchnorm = input_patchnorm if input_patchnorm: patch_input_dim = patch_height * patch_width * 3 self.patchnorm_pre_ln = LayerNorm(patch_input_dim) self.conv1 = nn.Linear(patch_input_dim, width) else: self.patchnorm_pre_ln = nn.Identity() self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) # class embeddings and positional embeddings scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() self.ln_pre = norm_layer(width) self.transformer = Transformer( width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.global_average_pool = global_average_pool if attentional_pool: self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) else: self.attn_pool = None self.ln_post = norm_layer(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) self.init_parameters() def lock(self, unlocked_groups=0, freeze_bn_stats=False): for param in self.parameters(): param.requires_grad = False if unlocked_groups != 0: groups = [ [ self.conv1, self.class_embedding, self.positional_embedding, self.ln_pre, ], *self.transformer.resblocks[:-1], [ self.transformer.resblocks[-1], self.ln_post, ], self.proj, ] def _unlock(x): if isinstance(x, Sequence): for g in x: _unlock(g) else: if isinstance(x, torch.nn.Parameter): x.requires_grad = True else: for p in x.parameters(): p.requires_grad = True _unlock(groups[-unlocked_groups:]) def init_parameters(self): # FIXME OpenAI CLIP did not define an init for the VisualTransformer # TODO experiment if default PyTorch init, below, or alternate init is best. # nn.init.normal_(self.class_embedding, std=self.scale) # nn.init.normal_(self.positional_embedding, std=self.scale) # # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) # attn_std = self.transformer.width ** -0.5 # fc_std = (2 * self.transformer.width) ** -0.5 # for block in self.transformer.resblocks: # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) # # if self.text_projection is not None: # nn.init.normal_(self.text_projection, std=self.scale) pass @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: return x.mean(dim=1), x else: return x[:, 0], x[:, 1:] def forward(self, x: torch.Tensor, out_layers: list): # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 if self.input_patchnorm: # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) x = x.permute(0, 2, 4, 1, 3, 5) x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) x = self.patchnorm_pre_ln(x) x = self.conv1(x) else: x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] # class embeddings and positional embeddings x = torch.cat( [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in x = self.patch_dropout(x) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x, attn, patch_tokens = self.transformer(x, out_layers) # attn = attn[0, 0, 1:].view(14, 14) # 49 B, C, L = attn[0].shape H = int(np.sqrt(L-1)) out_attn = torch.zeros([H, H]).to('cuda') for i in range(len(attn)): out_attn += attn[i][0, 0, 1:].view(H, H) x = x.permute(1, 0, 2) # LND -> NLD patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD # patch_tokens = patch_tokens.permute(1, 0, 2) # LND -> NLD if self.attn_pool is not None: x = self.attn_pool(x) x = self.ln_post(x) pooled, tokens = self._global_pool(x) else: pooled, tokens = self._global_pool(x) pooled = self.ln_post(pooled) # patch_pooled, patch_tokens = self._global_pool(patch_tokens) # tokens = self.ln_post(tokens) if self.proj is not None: pooled = pooled @ self.proj # patch_tokens = patch_tokens @ self.proj # 不知道能不能行 # tokens = tokens @ self.proj if self.output_tokens: return pooled, patch_tokens return pooled, patch_tokens class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, context_length: int = 77, vocab_size: int = 49408, width: int = 512, heads: int = 8, layers: int = 12, ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, embed_cls: bool = False, pad_id: int = 0, output_tokens: bool = False, ): super().__init__() self.output_tokens = output_tokens self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim self.heads = heads self.pad_id = pad_id self.text_projection = nn.Parameter(torch.empty(width, output_dim)) if embed_cls: self.cls_emb = nn.Parameter(torch.empty(width)) self.num_pos += 1 else: self.cls_emb = None self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.ln_final = norm_layer(width) self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.init_parameters() def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) if self.cls_emb is not None: nn.init.normal_(self.cls_emb, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def build_attention_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.num_pos, self.num_pos) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def build_cls_mask(self, text, cast_dtype: torch.dtype): cls_mask = (text != 
self.pad_id).unsqueeze(1) cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) additive_mask.fill_(0) additive_mask.masked_fill_(~cls_mask, float("-inf")) additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask def _repeat(self, t, N: int): return t.reshape(1, 1, -1).repeat(N, 1, 1) def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) cls_mask = self.build_cls_mask(text, cast_dtype) attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask) x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) if self.cls_emb is not None: pooled, tokens = x[:, -1], x[:, :-1] pooled = self.ln_final(pooled) else: x = self.ln_final(x) pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x if self.text_projection is not None: pooled = pooled @ self.text_projection if self.output_tokens: return pooled, tokens return pooled class MultimodalTransformer(Transformer): def __init__( self, width: int, layers: int, heads: int, context_length: int = 77, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_dim: int = 512, ): super().__init__( width=width, layers=layers, heads=heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.context_length = context_length self.cross_attn = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, is_cross_attention=True, ) for _ in range(layers) ]) self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.ln_final = norm_layer(width) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) def init_parameters(self): proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) for block in self.transformer.cross_attn: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def forward(self, image_embs, 
text_embs): text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq image_embs = image_embs.permute(1, 0, 2) # NLD -> LND seq_len = text_embs.shape[0] for resblock, cross_attn in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) else: text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) x = text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) if self.text_projection is not None: x = x @ self.text_projection return x @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable ================================================ FILE: open_clip/utils.py ================================================ from itertools import repeat import collections.abc from torch import nn as nn from torchvision.ops.misc import FrozenBatchNorm2d def freeze_batch_norm_2d(module, module_match={}, name=''): """ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and returned. Otherwise, the module is walked recursively and submodules are converted in place. Args: module (torch.nn.Module): Any PyTorch module. module_match (dict): Dictionary of full module names to freeze (all if empty) name (str): Full module name (prefix) Returns: torch.nn.Module: Resulting module Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module is_match = True if module_match: is_match = name in module_match if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): res = FrozenBatchNorm2d(module.num_features) res.num_features = module.num_features res.affine = module.affine if module.affine: res.weight.data = module.weight.data.clone().detach() res.bias.data = module.bias.data.clone().detach() res.running_mean.data = module.running_mean.data res.running_var.data = module.running_var.data res.eps = module.eps else: for child_name, child in module.named_children(): full_child_name = '.'.join([name, child_name]) if name else child_name new_child = freeze_batch_norm_2d(child, module_match, full_child_name) if new_child is not child: res.add_module(child_name, new_child) return res # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = lambda n, x: _ntuple(n)(x) ================================================ FILE: open_clip/version.py ================================================ __version__ = '2.16.0' ================================================ FILE: prompt_ensemble.py ================================================ import os from typing import Union, List from pkg_resources import packaging import torch import numpy as np def encode_text_with_prompt_ensemble(model, objs, tokenizer, device): prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} 
without damage'] prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage'] prompt_state = [prompt_normal, prompt_abnormal] prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.'] text_prompts = {} for obj in objs: text_features = [] for i in range(len(prompt_state)): prompted_state = [state.format(obj) for state in prompt_state[i]] prompted_sentence = [] for s in prompted_state: for template in prompt_templates: prompted_sentence.append(template.format(s)) prompted_sentence = tokenizer(prompted_sentence).to(device) class_embeddings = model.encode_text(prompted_sentence) class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) class_embedding = class_embeddings.mean(dim=0) class_embedding /= class_embedding.norm() text_features.append(class_embedding) text_features = torch.stack(text_features, dim=1).to(device) text_prompts[obj] = text_features return text_prompts ================================================ FILE: requirements.txt ================================================ ftfy==6.1.1 horovod==0.28.1 huggingface_hub==0.13.4 numpy==1.21.6 opencv_python==4.6.0.66 pandas==1.3.5 Pillow==9.2.0 regex==2022.10.31 scikit_image==0.19.3 scikit_learn==1.0.2 setuptools==63.4.1 tabulate==0.9.0 timm==0.8.15.dev0 torch==1.12.1 torchvision==0.13.1 tqdm==4.64.1 transformers==4.15.0 ================================================ FILE: test.py ================================================ import os import cv2 import json import torch import random import logging import argparse import numpy as np from PIL import Image from skimage import measure from tabulate import tabulate import torch.nn.functional as F import torchvision.transforms as transforms from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise import open_clip from few_shot import memory from model import LinearLayer from dataset import VisaDataset, MVTecDataset from prompt_ensemble import encode_text_with_prompt_ensemble def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def normalize(pred, max_value=None, min_value=None): if max_value is None or min_value is None: return (pred - pred.min()) / (pred.max() - pred.min()) else: return (pred - min_value) / (max_value - min_value) def apply_ad_scoremap(image, scoremap, alpha=0.5): np_image = np.asarray(image, dtype=float) scoremap = (scoremap * 255).astype(np.uint8) scoremap = 
cv2.applyColorMap(scoremap, cv2.COLORMAP_JET) scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB) return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8) def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3): # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py binary_amaps = np.zeros_like(amaps, dtype=bool) min_th, max_th = amaps.min(), amaps.max() delta = (max_th - min_th) / max_step pros, fprs, ths = [], [], [] for th in np.arange(min_th, max_th, delta): binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1 pro = [] for binary_amap, mask in zip(binary_amaps, masks): for region in measure.regionprops(measure.label(mask)): tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum() pro.append(tp_pixels / region.area) inverse_masks = 1 - masks fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() fpr = fp_pixels / inverse_masks.sum() pros.append(np.array(pro).mean()) fprs.append(fpr) ths.append(th) pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths) idxes = fprs < expect_fpr fprs = fprs[idxes] fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min()) pro_auc = auc(fprs, pros[idxes]) return pro_auc def test(args): img_size = args.image_size features_list = args.features_list few_shot_features = args.few_shot_features dataset_dir = args.data_path save_path = args.save_path dataset_name = args.dataset if not os.path.exists(save_path): os.makedirs(save_path) device = "cuda" if torch.cuda.is_available() else "cpu" txt_path = os.path.join(save_path, 'log.txt') # clip model, _, preprocess = open_clip.create_model_and_transforms(args.model, img_size, pretrained=args.pretrained) model.to(device) tokenizer = open_clip.get_tokenizer(args.model) # logger root_logger = logging.getLogger() for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) root_logger.setLevel(logging.WARNING) logger = logging.getLogger('test') formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') logger.setLevel(logging.INFO) file_handler = logging.FileHandler(txt_path, mode='w') file_handler.setFormatter(formatter) logger.addHandler(file_handler) console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logger.addHandler(console_handler) # record parameters for arg in vars(args): if args.mode == 'zero_shot' and (arg == 'k_shot' or arg == 'few_shot_features'): continue logger.info(f'{arg}: {getattr(args, arg)}') # seg with open(args.config_path, 'r') as f: model_configs = json.load(f) linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'], len(features_list), args.model).to(device) checkpoint = torch.load(args.checkpoint_path) linearlayer.load_state_dict(checkpoint["trainable_linearlayer"]) # dataset transform = transforms.Compose([ transforms.Resize((img_size, img_size)), transforms.CenterCrop(img_size), transforms.ToTensor() ]) if dataset_name == 'mvtec': test_data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform, aug_rate=-1, mode='test') else: test_data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, mode='test') test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) obj_list = test_data.get_cls_names() # few shot if args.mode == 'few_shot': mem_features = memory(args.model, model, obj_list, dataset_dir, save_path, preprocess, transform, args.k_shot, few_shot_features, dataset_name, device) # text prompt with 
torch.cuda.amp.autocast(), torch.no_grad(): text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device) results = {} results['cls_names'] = [] results['imgs_masks'] = [] results['anomaly_maps'] = [] results['gt_sp'] = [] results['pr_sp'] = [] for items in test_dataloader: image = items['img'].to(device) cls_name = items['cls_name'] results['cls_names'].append(cls_name[0]) gt_mask = items['img_mask'] gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0 results['imgs_masks'].append(gt_mask) # px results['gt_sp'].append(items['anomaly'].item()) with torch.no_grad(), torch.cuda.amp.autocast(): image_features, patch_tokens = model.encode_image(image, features_list) image_features /= image_features.norm(dim=-1, keepdim=True) text_features = [] for cls in cls_name: text_features.append(text_prompts[cls]) text_features = torch.stack(text_features, dim=0) # sample text_probs = (100.0 * image_features @ text_features[0]).softmax(dim=-1) results['pr_sp'].append(text_probs[0][1].cpu().item()) # pixel patch_tokens = linearlayer(patch_tokens) anomaly_maps = [] for layer in range(len(patch_tokens)): patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True) anomaly_map = (100.0 * patch_tokens[layer] @ text_features) B, L, C = anomaly_map.shape H = int(np.sqrt(L)) anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H), size=img_size, mode='bilinear', align_corners=True) anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :] anomaly_maps.append(anomaly_map.cpu().numpy()) anomaly_map = np.sum(anomaly_maps, axis=0) # few shot if args.mode == 'few_shot': image_features, patch_tokens = model.encode_image(image, few_shot_features) anomaly_maps_few_shot = [] for idx, p in enumerate(patch_tokens): if 'ViT' in args.model: p = p[0, 1:, :] else: p = p[0].view(p.shape[1], -1).permute(1, 0).contiguous() cos = pairwise.cosine_similarity(mem_features[cls_name[0]][idx].cpu(), p.cpu()) height = int(np.sqrt(cos.shape[1])) anomaly_map_few_shot = np.min((1 - cos), 0).reshape(1, 1, height, height) anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot), size=img_size, mode='bilinear', align_corners=True) anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy()) anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0) anomaly_map = anomaly_map + anomaly_map_few_shot results['anomaly_maps'].append(anomaly_map) # visualization path = items['img_path'] cls = path[0].split('/')[-2] filename = path[0].split('/')[-1] vis = cv2.cvtColor(cv2.resize(cv2.imread(path[0]), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB mask = normalize(anomaly_map[0]) vis = apply_ad_scoremap(vis, mask) vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR) # BGR save_vis = os.path.join(save_path, 'imgs', cls_name[0], cls) if not os.path.exists(save_vis): os.makedirs(save_vis) cv2.imwrite(os.path.join(save_vis, filename), vis) # metrics table_ls = [] auroc_sp_ls = [] auroc_px_ls = [] f1_sp_ls = [] f1_px_ls = [] aupro_ls = [] ap_sp_ls = [] ap_px_ls = [] for obj in obj_list: table = [] gt_px = [] pr_px = [] gt_sp = [] pr_sp = [] pr_sp_tmp = [] table.append(obj) for idxes in range(len(results['cls_names'])): if results['cls_names'][idxes] == obj: gt_px.append(results['imgs_masks'][idxes].squeeze(1).numpy()) pr_px.append(results['anomaly_maps'][idxes]) pr_sp_tmp.append(np.max(results['anomaly_maps'][idxes])) gt_sp.append(results['gt_sp'][idxes]) pr_sp.append(results['pr_sp'][idxes]) gt_px = np.array(gt_px) gt_sp = np.array(gt_sp) pr_px = np.array(pr_px) pr_sp = np.array(pr_sp) if 
args.mode == 'few_shot': pr_sp_tmp = np.array(pr_sp_tmp) pr_sp_tmp = (pr_sp_tmp - pr_sp_tmp.min()) / (pr_sp_tmp.max() - pr_sp_tmp.min()) pr_sp = 0.5 * (pr_sp + pr_sp_tmp) auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel()) auroc_sp = roc_auc_score(gt_sp, pr_sp) ap_sp = average_precision_score(gt_sp, pr_sp) ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel()) # f1_sp precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp) f1_scores = (2 * precisions * recalls) / (precisions + recalls) f1_sp = np.max(f1_scores[np.isfinite(f1_scores)]) # f1_px precisions, recalls, thresholds = precision_recall_curve(gt_px.ravel(), pr_px.ravel()) f1_scores = (2 * precisions * recalls) / (precisions + recalls) f1_px = np.max(f1_scores[np.isfinite(f1_scores)]) # aupro if len(gt_px.shape) == 4: gt_px = gt_px.squeeze(1) if len(pr_px.shape) == 4: pr_px = pr_px.squeeze(1) aupro = cal_pro_score(gt_px, pr_px) table.append(str(np.round(auroc_px * 100, decimals=1))) table.append(str(np.round(f1_px * 100, decimals=1))) table.append(str(np.round(ap_px * 100, decimals=1))) table.append(str(np.round(aupro * 100, decimals=1))) table.append(str(np.round(auroc_sp * 100, decimals=1))) table.append(str(np.round(f1_sp * 100, decimals=1))) table.append(str(np.round(ap_sp * 100, decimals=1))) table_ls.append(table) auroc_sp_ls.append(auroc_sp) auroc_px_ls.append(auroc_px) f1_sp_ls.append(f1_sp) f1_px_ls.append(f1_px) aupro_ls.append(aupro) ap_sp_ls.append(ap_sp) ap_px_ls.append(ap_px) # logger table_ls.append(['mean', str(np.round(np.mean(auroc_px_ls) * 100, decimals=1)), str(np.round(np.mean(f1_px_ls) * 100, decimals=1)), str(np.round(np.mean(ap_px_ls) * 100, decimals=1)), str(np.round(np.mean(aupro_ls) * 100, decimals=1)), str(np.round(np.mean(auroc_sp_ls) * 100, decimals=1)), str(np.round(np.mean(f1_sp_ls) * 100, decimals=1)), str(np.round(np.mean(ap_sp_ls) * 100, decimals=1))]) results = tabulate(table_ls, headers=['objects', 'auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp', 'f1_sp', 'ap_sp'], tablefmt="pipe") logger.info("\n%s", results) if __name__ == '__main__': parser = argparse.ArgumentParser("VAND Challenge", add_help=True) # paths parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset") parser.add_argument("--save_path", type=str, default='./results/tiaoshi', help='path to save results') parser.add_argument("--checkpoint_path", type=str, default='./exps/vit_huge_14/model_epoch12.pth', help='path to save results') parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16.json', help="model configs") # model parser.add_argument("--dataset", type=str, default='mvtec', help="test dataset") parser.add_argument("--model", type=str, default="ViT-B-16", help="model used") parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used") parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9], help="features used") parser.add_argument("--few_shot_features", type=int, nargs="+", default=[3, 6, 9], help="features used for few shot") parser.add_argument("--image_size", type=int, default=224, help="image size") parser.add_argument("--mode", type=str, default="zero_shot", help="zero shot or few shot") # few shot parser.add_argument("--k_shot", type=int, default=10, help="e.g., 10-shot, 5-shot, 1-shot") parser.add_argument("--seed", type=int, default=10, help="random seed") args = parser.parse_args() setup_seed(args.seed) test(args) 
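For reference, the zero-shot scoring in `test.py` above reduces to a small amount of tensor algebra: the image-level score is the softmax probability of the "abnormal" prompt embedding, and the pixel-level map comes from the two-class similarity of each mapped patch token, upsampled to the image size. Below is a minimal, self-contained sketch of that computation; random tensors stand in for the real CLIP features, and the sizes (C, L_patches, img_size) are illustrative, while the 100.0 scale, two-channel softmax and bilinear upsampling mirror the code above.

```python
import torch
import torch.nn.functional as F

# Stand-in features; in test.py these come from model.encode_image and the
# prompt-ensembled text embeddings of shape [embed_dim, 2] (normal vs. abnormal).
C, L_patches, img_size = 768, 24 * 24, 518
image_features = F.normalize(torch.randn(1, C), dim=-1)           # pooled image embedding
text_features = F.normalize(torch.randn(C, 2), dim=0)             # [embed_dim, 2]
patch_tokens = F.normalize(torch.randn(1, L_patches, C), dim=-1)  # one mapped feature layer

# Image-level anomaly score: probability assigned to the "abnormal" prompt set.
text_probs = (100.0 * image_features @ text_features).softmax(dim=-1)
image_score = text_probs[0, 1].item()

# Pixel-level anomaly map: per-patch two-class similarity, reshaped to a grid,
# upsampled to the input resolution, then the softmax probability of the abnormal channel.
B, L, _ = patch_tokens.shape
H = int(L ** 0.5)
anomaly_map = 100.0 * patch_tokens @ text_features                 # [B, L, 2]
anomaly_map = anomaly_map.permute(0, 2, 1).view(B, 2, H, H)
anomaly_map = F.interpolate(anomaly_map, size=img_size, mode='bilinear', align_corners=True)
anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]        # [B, img_size, img_size]

print(image_score, anomaly_map.shape)
```

In the full pipeline this is repeated for every layer in `features_list` and the per-layer maps are summed; in the few-shot mode the cosine distance to the memorised reference patches is added on top.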

================================================ FILE: test_few_shot.sh ================================================

### test on the VisA dataset
python test.py --mode few_shot --dataset visa \
--data_path ./data/visa --save_path ./results/visa/few_shot/4shot/seed42 \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \
--pretrained openai --image_size 518 --k_shot 4 --seed 42

### test on the MVTec AD dataset
python test.py --mode few_shot --dataset mvtec \
--data_path ./data/mvtec --save_path ./results/mvtec/few_shot/4shot/seed42 \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \
--pretrained openai --image_size 518 --k_shot 4 --seed 42

================================================ FILE: test_zero_shot.sh ================================================

### test on the VisA dataset
python test.py --mode zero_shot --dataset visa \
--data_path ./data/visa --save_path ./results/visa/zero_shot \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518

### test on the MVTec AD dataset
python test.py --mode zero_shot --dataset mvtec \
--data_path ./data/mvtec --save_path ./results/mvtec/zero_shot \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518

================================================ FILE: train.py ================================================

import torch
import torch.nn as nn
import numpy as np
import random
import os
import json
import argparse
from torch.utils.data import DataLoader
from datetime import datetime
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import logging

import open_clip
from dataset import VisaDataset, MVTecDataset
from model import LinearLayer
from loss import FocalLoss, BinaryDiceLoss
from prompt_ensemble import encode_text_with_prompt_ensemble


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train(args):
    # configs
    epochs = args.epoch
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    image_size = args.image_size
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    save_path = args.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    txt_path = os.path.join(save_path, 'log.txt')  # log

    # model configs
    features_list = args.features_list
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)

    # clip model
    model, _, preprocess = open_clip.create_model_and_transforms(args.model, image_size, pretrained=args.pretrained)
    model.to(device)
    tokenizer = open_clip.get_tokenizer(args.model)

    # logger
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('train')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
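    # send training logs both to log.txt under save_path and to the console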
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # record parameters
    for arg in vars(args):
        logger.info(f'{arg}: {getattr(args, arg)}')

    # transforms
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])

    # datasets
    if args.dataset == 'mvtec':
        train_data = MVTecDataset(root=args.train_data_path, transform=preprocess, target_transform=transform,
                                  aug_rate=args.aug_rate)
    else:
        train_data = VisaDataset(root=args.train_data_path, transform=preprocess, target_transform=transform)
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # linear layer
    trainable_layer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
                                  len(args.features_list), args.model).to(device)
    optimizer = torch.optim.Adam(list(trainable_layer.parameters()), lr=learning_rate, betas=(0.5, 0.999))

    # losses
    loss_focal = FocalLoss()
    loss_dice = BinaryDiceLoss()

    # text prompt
    with torch.cuda.amp.autocast(), torch.no_grad():
        obj_list = train_data.get_cls_names()
        text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)

    for epoch in range(epochs):
        loss_list = []
        idx = 0
        for items in train_dataloader:
            idx += 1
            image = items['img'].to(device)
            cls_name = items['cls_name']

            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    image_features, patch_tokens = model.encode_image(image, features_list)
                    text_features = []
                    for cls in cls_name:
                        text_features.append(text_prompts[cls])
                    text_features = torch.stack(text_features, dim=0)

                # pixel level
                patch_tokens = trainable_layer(patch_tokens)
                anomaly_maps = []
                for layer in range(len(patch_tokens)):
                    patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)
                    anomaly_map = (100.0 * patch_tokens[layer] @ text_features)
                    B, L, C = anomaly_map.shape
                    H = int(np.sqrt(L))
                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                                size=image_size, mode='bilinear', align_corners=True)
                    anomaly_map = torch.softmax(anomaly_map, dim=1)
                    anomaly_maps.append(anomaly_map)

            # losses
            gt = items['img_mask'].squeeze().to(device)
            gt[gt > 0.5], gt[gt <= 0.5] = 1, 0
            loss = 0
            for num in range(len(anomaly_maps)):
                loss += loss_focal(anomaly_maps[num], gt)
                loss += loss_dice(anomaly_maps[num][:, 1, :, :], gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        # logs
        if (epoch + 1) % args.print_freq == 0:
            logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, np.mean(loss_list)))

        # save model
        if (epoch + 1) % args.save_freq == 0:
            ckp_path = os.path.join(save_path, 'epoch_' + str(epoch + 1) + '.pth')
            torch.save({'trainable_linearlayer': trainable_layer.state_dict()}, ckp_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser("VAND Challenge", add_help=True)
    # path
    parser.add_argument("--train_data_path", type=str, default="./data/visa", help="train dataset path")
    parser.add_argument("--save_path", type=str, default='./exps/vit_large_14_518', help='path to save results')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16.json',
                        help="model configs")
    # model
    parser.add_argument("--dataset", type=str, default='mvtec', help="train dataset name")
    parser.add_argument("--model", type=str, default="ViT-B-16", help="model used")
parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used") parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9], help="features used") # hyper-parameter parser.add_argument("--epoch", type=int, default=200, help="epochs") parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") parser.add_argument("--batch_size", type=int, default=16, help="batch size") parser.add_argument("--image_size", type=int, default=224, help="image size") parser.add_argument("--aug_rate", type=float, default=0.2, help="image size") parser.add_argument("--print_freq", type=int, default=30, help="print frequency") parser.add_argument("--save_freq", type=int, default=3, help="save frequency") args = parser.parse_args() setup_seed(111) train(args) ================================================ FILE: train.sh ================================================ ### train on the MVTec AD dataset python train.py --dataset mvtec --train_data_path ./data/mvtec \ --save_path ./exps/visa/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \ --features_list 6 12 18 24 --pretrained openai --image_size 518 --batch_size 8 --aug_rate 0.2 --print_freq 1 \ --epoch 3 --save_freq 1 ### train on the VisA dataset python train.py --dataset visa --train_data_path ./data/visa \ --save_path ./exps/mvtec/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \ --features_list 6 12 18 24 --pretrained openai --image_size 518 --batch_size 8 --print_freq 1 \ --epoch 15 --save_freq 1