Repository: ByChelsea/VAND-APRIL-GAN
Branch: master
Commit: f13b8a634e04
Files: 90
Total size: 24.2 MB
Directory structure:
gitextract_7tq5j4qy/
├── LICENSE
├── README.md
├── data/
│ ├── mvtec.py
│ └── visa.py
├── dataset.py
├── exps/
│ └── pretrained/
│ ├── mvtec_pretrained.pth
│ └── visa_pretrained.pth
├── few_shot.py
├── loss.py
├── model.py
├── open_clip/
│ ├── __init__.py
│ ├── coca_model.py
│ ├── constants.py
│ ├── factory.py
│ ├── generation_utils.py
│ ├── hf_configs.py
│ ├── hf_model.py
│ ├── loss.py
│ ├── model.py
│ ├── model_configs/
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── RN50x64.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-M-16-alt.json
│ │ ├── ViT-M-16.json
│ │ ├── ViT-M-32-alt.json
│ │ ├── ViT-M-32.json
│ │ ├── ViT-S-16-alt.json
│ │ ├── ViT-S-16.json
│ │ ├── ViT-S-32-alt.json
│ │ ├── ViT-S-32.json
│ │ ├── ViT-bigG-14.json
│ │ ├── ViT-e-14.json
│ │ ├── ViT-g-14.json
│ │ ├── coca_ViT-B-32.json
│ │ ├── coca_ViT-L-14.json
│ │ ├── coca_base.json
│ │ ├── coca_roberta-ViT-B-32.json
│ │ ├── convnext_base.json
│ │ ├── convnext_base_w.json
│ │ ├── convnext_base_w_320.json
│ │ ├── convnext_large.json
│ │ ├── convnext_large_d.json
│ │ ├── convnext_large_d_320.json
│ │ ├── convnext_small.json
│ │ ├── convnext_tiny.json
│ │ ├── convnext_xlarge.json
│ │ ├── convnext_xxlarge.json
│ │ ├── convnext_xxlarge_320.json
│ │ ├── mt5-base-ViT-B-32.json
│ │ ├── mt5-xl-ViT-H-14.json
│ │ ├── roberta-ViT-B-32.json
│ │ ├── swin_base_patch4_window7_224.json
│ │ ├── vit_medium_patch16_gap_256.json
│ │ ├── vit_relpos_medium_patch16_cls_224.json
│ │ ├── xlm-roberta-base-ViT-B-32.json
│ │ └── xlm-roberta-large-ViT-H-14.json
│ ├── modified_resnet.py
│ ├── openai.py
│ ├── pretrained.py
│ ├── push_to_hf_hub.py
│ ├── timm_model.py
│ ├── tokenizer.py
│ ├── transform.py
│ ├── transformer.py
│ ├── utils.py
│ └── version.py
├── prompt_ensemble.py
├── requirements.txt
├── test.py
├── test_few_shot.sh
├── test_zero_shot.sh
├── train.py
└── train.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 Xuhai Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
[Workshop Link](https://sites.google.com/view/vand-cvpr23/home) | [Challenge Link](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0) | [Report Paper](https://arxiv.org/abs/2305.17382)
---
[Xuhai Chen](https://scholar.google.com.hk/citations?user=LU4etJ0AAAAJ&hl=zh-CN&authuser=1) · [Yue Han](https://scholar.google.com/citations?hl=en&user=08E500gAAAAJ&view_op=list_works&gmla=AHoSzlVzTXnclaPp9h1g8xAZQBrsxdFXvhunMA3AmRm_GSLnZA1956xavx6hmPaCFCysonsXeTQyhB_cokdUFacUc5HBunMPW-uOtLZLTTufiZiHB6hAVgr9l7cJ_UHKeQ) · [Jiangning Zhang](https://zhangzjn.github.io/)
This repository contains the official PyTorch implementation of [Zero-/Few-shot Anomaly Classification and Segmentation Method](https://arxiv.org/abs/2305.17382) used in the [CVPR 2023 VAND Challenge](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0), which can be viewd as an improved version of [WinCLIP](https://arxiv.org/abs/2303.14814). We achieve **Winner** in the Zero-shot Track and **Honorable Mentions** in the Few-shot Track.
**Results on the Challenge official test set**
## Installation
- Prepare experimental environments
```shell
pip install -r requirements.txt
```
## Dataset Preparation
### MVTec AD
- Download and extract [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) into `data/mvtec`
- run`python data/mvtec.py` to obtain `data/mvtec/meta.json`
```
data
├── mvtec
├── meta.json
├── bottle
├── train
├── good
├── 000.png
├── test
├── good
├── 000.png
├── anomaly1
├── 000.png
├── ground_truth
├── anomaly1
├── 000.png
```
### VisA
- Download and extract [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar) into `data/visa`
- run`python data/visa.py` to obtain `data/visa/meta.json`
```
data
├── visa
├── meta.json
├── candle
├── Data
├── Images
├── Anomaly
├── 000.JPG
├── Normal
├── 0000.JPG
├── Masks
├── Anomaly
├── 000.png
```
## Train
Set parameters in `train.sh`.
- `train_data_path`: the path to the training dataset
- `dataset`: name of the training dataset, optional: mvtec, visa
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images inputted into the CLIP model
- `aug_rate`: the probability of stitching images (only for mvtec)
Then run the following command
```shell
sh train.sh
```
## Test
### Pretrained Models
We provide our pre-trained models in `exps/pretrained`, where `mvtec_pretrained.pth` represents the model trained on the MVTec AD dataset and `visa_pretrained.pth` represents the model trained on the VisA dataset.
Set parameters in `test_zero_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, optional: mvtec, visa
- `checkpoint_path`: the path to the test model
Then, run the following command to test them in the zero-shot setting:
```shell
sh test_zero_shot.sh
```
Set parameters in `test_few_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, optional: mvtec, visa
- `checkpoint_path`: the path to the test model
- `k_shot`: different number of reference images
Then, run the following command to test them in the few-shot setting:
```shell
sh test_few_shot.sh
```
### Zero-shot Setting
Set parameters in `test_zero_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, optional: mvtec, visa
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images inputted into the CLIP model
- `mode`: zero shot or few shot
Then run the following command
```shell
sh test_zero_shot.sh
```
### Few-shot Setting
Set parameters in `test_few_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, optional: mvtec, visa
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `few_shot_features`: features stored in the memory banks
- `image_size`: the size of the images inputted into the CLIP model
- `mode`: zero shot or few shot
- `k_shot`: different number of reference images
- `seed`: the random seed
Then run the following command
```shell
sh test_few_shot.sh
```
## Citation
If our work is helpful for your research, please consider citing:
```
@article{chen2023zero,
title={A Zero-/Few-Shot Anomaly Classification and Segmentation Method for CVPR 2023 VAND Workshop Challenge Tracks 1\&2: 1st Place on Zero-shot AD and 4th Place on Few-shot AD},
author={Chen, Xuhai and Han, Yue and Zhang, Jiangning},
journal={arXiv preprint arXiv:2305.17382},
year={2023}
}
```
## Acknowledgements
We thank [WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation](https://arxiv.org/abs/2303.14814) for providing assistance for our research.
================================================
FILE: data/mvtec.py
================================================
import os
import json
class MVTecSolver(object):
CLSNAMES = [
'bottle', 'cable', 'capsule', 'carpet', 'grid',
'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
]
def __init__(self, root='data/mvtec'):
self.root = root
self.meta_path = f'{root}/meta.json'
def run(self):
info = dict(train={}, test={})
for cls_name in self.CLSNAMES:
cls_dir = f'{self.root}/{cls_name}'
for phase in ['train', 'test']:
cls_info = []
species = os.listdir(f'{cls_dir}/{phase}')
for specie in species:
is_abnormal = True if specie not in ['good'] else False
img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
img_names.sort()
mask_names.sort() if mask_names is not None else None
for idx, img_name in enumerate(img_names):
info_img = dict(
img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
cls_name=cls_name,
specie_name=specie,
anomaly=1 if is_abnormal else 0,
)
cls_info.append(info_img)
info[phase][cls_name] = cls_info
with open(self.meta_path, 'w') as f:
f.write(json.dumps(info, indent=4) + "\n")
if __name__ == '__main__':
runner = MVTecSolver(root='data/mvtec')
runner.run()
================================================
FILE: data/visa.py
================================================
import os
import json
import pandas as pd
class VisASolver(object):
CLSNAMES = [
'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
'pcb4', 'pipe_fryum',
]
def __init__(self, root='data/visa'):
self.root = root
self.meta_path = f'{root}/meta.json'
self.phases = ['train', 'test']
self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
def run(self):
columns = self.csv_data.columns # [object, split, label, image, mask]
info = {phase: {} for phase in self.phases}
for cls_name in self.CLSNAMES:
cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name]
for phase in self.phases:
cls_info = []
cls_data_phase = cls_data[cls_data[columns[1]] == phase]
cls_data_phase.index = list(range(len(cls_data_phase)))
for idx in range(cls_data_phase.shape[0]):
data = cls_data_phase.loc[idx]
is_abnormal = True if data[2] == 'anomaly' else False
info_img = dict(
img_path=data[3],
mask_path=data[4] if is_abnormal else '',
cls_name=cls_name,
specie_name='',
anomaly=1 if is_abnormal else 0,
)
cls_info.append(info_img)
info[phase][cls_name] = cls_info
with open(self.meta_path, 'w') as f:
f.write(json.dumps(info, indent=4) + "\n")
if __name__ == '__main__':
runner = VisASolver(root='data/visa')
runner.run()
================================================
FILE: dataset.py
================================================
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os
class VisaDataset(data.Dataset):
def __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.data_all = []
meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
name = self.root.split('/')[-1]
meta_info = meta_info[mode]
if mode == 'train':
self.cls_names = [obj_name]
save_dir = os.path.join(save_dir, 'k_shot.txt')
else:
self.cls_names = list(meta_info.keys())
for cls_name in self.cls_names:
if mode == 'train':
data_tmp = meta_info[cls_name]
indices = torch.randint(0, len(data_tmp), (k_shot,))
for i in range(len(indices)):
self.data_all.append(data_tmp[indices[i]])
with open(save_dir, "a") as f:
f.write(data_tmp[indices[i]]['img_path'] + '\n')
else:
self.data_all.extend(meta_info[cls_name])
self.length = len(self.data_all)
def __len__(self):
return self.length
def get_cls_names(self):
return self.cls_names
def __getitem__(self, index):
data = self.data_all[index]
img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
data['specie_name'], data['anomaly']
img = Image.open(os.path.join(self.root, img_path))
if anomaly == 0:
img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
else:
img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
img = self.transform(img) if self.transform is not None else img
img_mask = self.target_transform(
img_mask) if self.target_transform is not None and img_mask is not None else img_mask
img_mask = [] if img_mask is None else img_mask
return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
'img_path': os.path.join(self.root, img_path)}
class MVTecDataset(data.Dataset):
def __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None, obj_name=None):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.aug_rate = aug_rate
self.data_all = []
meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
name = self.root.split('/')[-1]
meta_info = meta_info[mode]
if mode == 'train':
self.cls_names = [obj_name]
save_dir = os.path.join(save_dir, 'k_shot.txt')
else:
self.cls_names = list(meta_info.keys())
for cls_name in self.cls_names:
if mode == 'train':
data_tmp = meta_info[cls_name]
indices = torch.randint(0, len(data_tmp), (k_shot,))
for i in range(len(indices)):
self.data_all.append(data_tmp[indices[i]])
with open(save_dir, "a") as f:
f.write(data_tmp[indices[i]]['img_path'] + '\n')
else:
self.data_all.extend(meta_info[cls_name])
self.length = len(self.data_all)
def __len__(self):
return self.length
def get_cls_names(self):
return self.cls_names
def combine_img(self, cls_name):
img_paths = os.path.join(self.root, cls_name, 'test')
img_ls = []
mask_ls = []
for i in range(4):
defect = os.listdir(img_paths)
random_defect = random.choice(defect)
files = os.listdir(os.path.join(img_paths, random_defect))
random_file = random.choice(files)
img_path = os.path.join(img_paths, random_defect, random_file)
mask_path = os.path.join(self.root, cls_name, 'ground_truth', random_defect, random_file[:3] + '_mask.png')
img = Image.open(img_path)
img_ls.append(img)
if random_defect == 'good':
img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
else:
img_mask = np.array(Image.open(mask_path).convert('L')) > 0
img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
mask_ls.append(img_mask)
# image
image_width, image_height = img_ls[0].size
result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
for i, img in enumerate(img_ls):
row = i // 2
col = i % 2
x = col * image_width
y = row * image_height
result_image.paste(img, (x, y))
# mask
result_mask = Image.new("L", (2 * image_width, 2 * image_height))
for i, img in enumerate(mask_ls):
row = i // 2
col = i % 2
x = col * image_width
y = row * image_height
result_mask.paste(img, (x, y))
return result_image, result_mask
def __getitem__(self, index):
data = self.data_all[index]
img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
data['specie_name'], data['anomaly']
random_number = random.random()
if random_number < self.aug_rate:
img, img_mask = self.combine_img(cls_name)
else:
img = Image.open(os.path.join(self.root, img_path))
if anomaly == 0:
img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
else:
img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
# transforms
img = self.transform(img) if self.transform is not None else img
img_mask = self.target_transform(
img_mask) if self.target_transform is not None and img_mask is not None else img_mask
img_mask = [] if img_mask is None else img_mask
return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
'img_path': os.path.join(self.root, img_path)}
================================================
FILE: exps/pretrained/mvtec_pretrained.pth
================================================
[File too large to display: 12.0 MB]
================================================
FILE: exps/pretrained/visa_pretrained.pth
================================================
[File too large to display: 12.0 MB]
================================================
FILE: few_shot.py
================================================
import torch
from dataset import VisaDataset, MVTecDataset
def memory(model_name, model, obj_list, dataset_dir, save_path, preprocess, transform, k_shot,
few_shot_features, dataset_name, device):
mem_features = {}
for obj in obj_list:
if dataset_name == 'mvtec':
data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
aug_rate=-1, mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)
else:
data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)
dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
features = []
for items in dataloader:
image = items['img'].to(device)
with torch.no_grad():
image_features, patch_tokens = model.encode_image(image, few_shot_features)
if 'ViT' in model_name:
patch_tokens = [p[0, 1:, :] for p in patch_tokens]
else:
patch_tokens = [p[0].view(p.shape[1], -1).permute(1, 0).contiguous() for p in patch_tokens]
features.append(patch_tokens)
mem_features[obj] = [torch.cat(
[features[j][i] for j in range(len(features))], dim=0) for i in range(len(features[0]))]
return mem_features
================================================
FILE: loss.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import exp
class FocalLoss(nn.Module):
"""
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
Focal_Loss= -1*alpha*(1-pt)*log(pt)
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
focus on hard misclassified example
:param smooth: (float,double) smooth value when cross entropy
:param balance_index: (int) balance class index, should be specific when alpha is float
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
"""
def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
super(FocalLoss, self).__init__()
self.apply_nonlin = apply_nonlin
self.alpha = alpha
self.gamma = gamma
self.balance_index = balance_index
self.smooth = smooth
self.size_average = size_average
if self.smooth is not None:
if self.smooth < 0 or self.smooth > 1.0:
raise ValueError('smooth value should be in [0,1]')
def forward(self, logit, target):
if self.apply_nonlin is not None:
logit = self.apply_nonlin(logit)
num_class = logit.shape[1]
if logit.dim() > 2:
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
logit = logit.view(logit.size(0), logit.size(1), -1)
logit = logit.permute(0, 2, 1).contiguous()
logit = logit.view(-1, logit.size(-1))
target = torch.squeeze(target, 1)
target = target.view(-1, 1)
alpha = self.alpha
if alpha is None:
alpha = torch.ones(num_class, 1)
elif isinstance(alpha, (list, np.ndarray)):
assert len(alpha) == num_class
alpha = torch.FloatTensor(alpha).view(num_class, 1)
alpha = alpha / alpha.sum()
elif isinstance(alpha, float):
alpha = torch.ones(num_class, 1)
alpha = alpha * (1 - self.alpha)
alpha[self.balance_index] = self.alpha
else:
raise TypeError('Not support alpha type')
if alpha.device != logit.device:
alpha = alpha.to(logit.device)
idx = target.cpu().long()
one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
if one_hot_key.device != logit.device:
one_hot_key = one_hot_key.to(logit.device)
if self.smooth:
one_hot_key = torch.clamp(
one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + self.smooth
logpt = pt.log()
gamma = self.gamma
alpha = alpha[idx]
alpha = torch.squeeze(alpha)
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
if self.size_average:
loss = loss.mean()
return loss
class BinaryDiceLoss(nn.Module):
def __init__(self):
super(BinaryDiceLoss, self).__init__()
def forward(self, input, targets):
# 获取每个批次的大小 N
N = targets.size()[0]
# 平滑变量
smooth = 1
# 将宽高 reshape 到同一纬度
input_flat = input.view(N, -1)
targets_flat = targets.view(N, -1)
# 计算交集
intersection = input_flat * targets_flat
N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
# 计算一个批次中平均每张图的损失
loss = 1 - N_dice_eff.sum() / N
return loss
================================================
FILE: model.py
================================================
from torch import Tensor, nn
import torch
from torch.nn import functional as F
class LinearLayer(nn.Module):
def __init__(self, dim_in, dim_out, k, model):
super(LinearLayer, self).__init__()
if 'ViT' in model:
self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for i in range(k)])
else:
self.fc = nn.ModuleList([nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)])
def forward(self, tokens):
for i in range(len(tokens)):
if len(tokens[i].shape) == 3:
tokens[i] = self.fc[i](tokens[i][:, 1:, :])
else:
B, C, H, W = tokens[i].shape
tokens[i] = self.fc[i](tokens[i].view(B, C, -1).permute(0, 2, 1).contiguous())
return tokens
================================================
FILE: open_clip/__init__.py
================================================
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
================================================
FILE: open_clip/coca_model.py
================================================
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StoppingCriteriaList
)
GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return decoder
class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)
self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
# def _encode_image(self, images, out_layers, normalize=True):
# image_latent, tokens_embs = self.visual(images, out_layers)
# image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
# return image_latent, tokens_embs
def _encode_image(self, images, out_layers, normalize=True):
image_latent = self.visual(images, out_layers)
return image_latent
def _encode_text(self, text, normalize=True, embed_cls=True):
text = text[:, :-1] if embed_cls else text # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
# def encode_image(self, images, out_layers, normalize=True):
# image_latent, _ = self._encode_image(images, out_layers, normalize=normalize)
# return image_latent
def encode_image(self, images, out_layers, normalize=True):
image_latent = self._encode_image(images, out_layers, normalize=normalize)
return image_latent
def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
logits = self.text_decoder(image_embs, token_embs)
return {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}
def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
with torch.no_grad():
sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
stopping_criteria = StoppingCriteriaList(
stopping_criteria
)
device = image.device
if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs = image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
return torch.cat(
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
dim=1
)
return output
elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)
image_latent, image_embs = self._encode_image(image)
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)
if num_dims == 1:
text = text[None, :]
cur_len = text.shape[1]
self.eval()
out = text
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)
if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
cur_len += 1
if stopping_criteria(out, None):
break
if num_dims == 1:
out = out.squeeze(0)
self.train(was_training)
return out
def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
# the same group don't produce same tokens everytime.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of currentg group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
break
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
================================================
FILE: open_clip/constants.py
================================================
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
================================================
FILE: open_clip/factory.py
================================================
import json
import logging
import os
import pathlib
import re
import numpy as np
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize
HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def get_tokenizer(model_name):
if model_name.startswith(HF_HUB_PREFIX):
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
else:
config = get_model_config(model_name)
tokenizer = HFTokenizer(
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
return tokenizer
def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
return state_dict
def load_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def create_model(
model_name: str,
img_size: int,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
pretrained_cfg = config['preprocess_cfg']
model_cfg = config['model_cfg']
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
pretrained_cfg = {}
model_cfg = None
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg['vision_cfg']['image_size'] != img_size:
model_cfg['vision_cfg']['image_size'] = img_size
cast_dtype = get_cast_dtype(precision)
model_pre = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
state_dict = model_pre.state_dict()
# to always output dict even if it is clip
if output_dict and hasattr(model_pre, "output_dict"):
model_pre.output_dict = True
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
### for resnet
if not hasattr(model.visual, 'grid_size'):
model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1))
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=True)
model.to(device=device)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
else:
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
else:
model_cfg = model_cfg or get_model_config(model_name)
model_cfg['vision_cfg']['image_size'] = img_size
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
pass
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_image_size is not None:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
if pretrained_image:
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
if custom_text:
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
model.to(device=device)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
return model
def create_loss(args):
if args.distill:
return DistillClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
def create_model_and_transforms(
model_name: str,
img_size: int,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
):
model = create_model(
model_name,
img_size,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_image_size=force_image_size,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
)
return model, preprocess_train, preprocess_val
def create_model_from_pretrained(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_image_size=force_image_size,
cache_dir=cache_dir,
require_pretrained=True,
)
if not return_transform:
return model
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
)
return model, preprocess
================================================
FILE: open_clip/generation_utils.py
================================================
================================================
FILE: open_clip/hf_configs.py
================================================
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
}
================================================
FILE: open_clip/hf_model.py
================================================
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(? torch.Tensor:
# calculated ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
return logits_per_image, logits_per_text
def forward(self, image_features, text_features, logit_scale, output_dict=False):
device = image_features.device
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
labels = self.get_ground_truth(device, logits_per_image.shape[0])
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return {"contrastive_loss": total_loss} if output_dict else total_loss
class CoCaLoss(ClipLoss):
def __init__(
self,
caption_loss_weight,
clip_loss_weight,
pad_id=0, # pad_token for open_clip custom tokenizer
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__(
local_loss=local_loss,
gather_with_grad=gather_with_grad,
cache_labels=cache_labels,
rank=rank,
world_size=world_size,
use_horovod=use_horovod
)
self.clip_loss_weight = clip_loss_weight
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
labels,
)
caption_loss = caption_loss * self.caption_loss_weight
if output_dict:
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
return clip_loss, caption_loss
class DistillClipLoss(ClipLoss):
def dist_loss(self, teacher_logits, student_logits):
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
def forward(
self,
image_features,
text_features,
logit_scale,
dist_image_features,
dist_text_features,
dist_logit_scale,
output_dict=False,
):
logits_per_image, logits_per_text = \
self.get_logits(image_features, text_features, logit_scale)
dist_logits_per_image, dist_logits_per_text = \
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
contrastive_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
distill_loss = (
self.dist_loss(dist_logits_per_image, logits_per_image) +
self.dist_loss(dist_logits_per_text, logits_per_text)
) / 2
if output_dict:
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
return contrastive_loss, distill_loss
================================================
FILE: open_clip/model.py
================================================
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
n_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
output_tokens: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
embed_cls: bool = False
pad_id: int = 0
output_tokens: bool = False
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
drop=vision_cfg.timm_drop,
drop_path=vision_cfg.timm_drop_path,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
input_patchnorm=vision_cfg.input_patchnorm,
global_average_pool=vision_cfg.global_average_pool,
attentional_pool=vision_cfg.attentional_pool,
n_queries=vision_cfg.n_queries,
attn_pooler_heads=vision_cfg.attn_pooler_heads,
output_tokens=vision_cfg.output_tokens,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
pretrained=text_cfg.hf_model_pretrained,
output_tokens=text_cfg.output_tokens,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
output_tokens=text_cfg.output_tokens,
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
)
return text
class CLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
def encode_image(self, image, out_layers, normalize: bool = False):
features = self.visual(image, out_layers)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if self.output_dict:
return {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()
class CustomTextCLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if self.output_dict:
return {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers,
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
# Rescale the grid of position embeddings when loading from state_dict
flag = 1
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None:
flag = 0
old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
if flag:
state_dict['visual.positional_embedding'] = new_pos_embed
else:
state_dict['visual.attnpool.positional_embedding'] = new_pos_embed
================================================
FILE: open_clip/model_configs/RN101-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN101.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN50-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN50.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN50x16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 384,
"layers": [
6,
8,
18,
8
],
"width": 96,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN50x4.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 288,
"layers": [
4,
6,
10,
6
],
"width": 80,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/RN50x64.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 448,
"layers": [
3,
15,
36,
10
],
"width": 128,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-16-plus-240.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 240,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-16-plus.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-32-plus-256.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 256,
"layers": 12,
"width": 896,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-32-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: open_clip/model_configs/ViT-H-16.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: open_clip/model_configs/ViT-L-14-280.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 280,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-L-14-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-L-16-320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 320,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-L-16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-M-16-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16,
"ls_init_value": 1e-4
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-M-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-M-32-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-M-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-S-16-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: open_clip/model_configs/ViT-S-16.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-S-32-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: open_clip/model_configs/ViT-S-32.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/ViT-bigG-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
}
}
================================================
FILE: open_clip/model_configs/ViT-e-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 56,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.5715,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 36
}
}
================================================
FILE: open_clip/model_configs/ViT-g-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: open_clip/model_configs/coca_ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"attn_pooler_heads": 8
},
"custom_text": true
}
================================================
FILE: open_clip/model_configs/coca_ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"attn_pooler_heads": 12
},
"custom_text": true
}
================================================
FILE: open_clip/model_configs/coca_base.json
================================================
{
"embed_dim": 512,
"multimodal_cfg": {
"width": 768,
"context_length": 76,
"vocab_size": 64000,
"mlp_ratio": 4,
"layers": 12,
"dim_head": 64,
"heads": 12,
"n_queries": 256,
"attn_pooler_heads": 8
},
"vision_cfg": {
"image_size": 288,
"layers": 12,
"width": 768,
"patch_size": 18,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 64000,
"layers": 12,
"heads": 12,
"width": 768,
"embed_cls": true,
"output_tokens": true
},
"custom_text": true
}
================================================
FILE: open_clip/model_configs/coca_roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"output_tokens": true
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "linear",
"width": 768,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"width": 768,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: open_clip/model_configs/convnext_base.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_base_w.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_base_w_320.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_large.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_large_d.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: open_clip/model_configs/convnext_large_d_320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: open_clip/model_configs/convnext_small.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_small",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_tiny.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_tiny",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/convnext_xlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 20
}
}
================================================
FILE: open_clip/model_configs/convnext_xxlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: open_clip/model_configs/convnext_xxlarge_320.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: open_clip/model_configs/mt5-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "google/mt5-base",
"hf_tokenizer_name": "google/mt5-base",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
================================================
FILE: open_clip/model_configs/mt5-xl-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "google/mt5-xl",
"hf_tokenizer_name": "google/mt5-xl",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
================================================
FILE: open_clip/model_configs/roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
================================================
FILE: open_clip/model_configs/swin_base_patch4_window7_224.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "swin_base_patch4_window7_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/vit_medium_patch16_gap_256.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_medium_patch16_gap_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: open_clip/model_configs/xlm-roberta-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "xlm-roberta-base",
"hf_tokenizer_name": "xlm-roberta-base",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
================================================
FILE: open_clip/model_configs/xlm-roberta-large-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "xlm-roberta-large",
"hf_tokenizer_name": "xlm-roberta-large",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
================================================
FILE: open_clip/modified_resnet.py
================================================
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from open_clip.utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x, out_blocks):
x = self.stem(x)
x_1 = self.layer1(x)
x_2 = self.layer2(x_1)
x_3 = self.layer3(x_2)
x_4 = self.layer4(x_3)
x = self.attnpool(x_4)
out_tokens = []
x_blocks = [x_1, x_2, x_3, x_4]
for i in out_blocks:
out_tokens.append(x_blocks[i - 1])
return x, out_tokens
================================================
FILE: open_clip/openai.py
================================================
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag('openai')
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
jit: bool = True,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
precision: str
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
cache_dir : Optional[str]
The directory to cache the downloaded model weights
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = 'fp32' if device == 'cpu' else 'fp16'
if get_pretrained_url(name, 'openai'):
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
if precision.startswith('amp') or precision == 'fp32':
model.float()
elif precision == 'bf16':
convert_weights_to_lp(model, dtype=torch.bfloat16)
return model
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 (typically for CPU)
if precision == 'fp32':
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
# ensure image_size attr available at consistent location for both jit and non-jit
model.visual.image_size = model.input_resolution.item()
return model
================================================
FILE: open_clip/pretrained.py
================================================
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union
from tqdm import tqdm
from .version import __version__
try:
from huggingface_hub import hf_hub_download
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', mean=None, std=None):
return dict(
url=url,
hf_hub=hf_hub,
mean=mean,
std=std,
)
_RN50 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN50_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN101 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN101_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN50x4 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)
_RN50x16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)
_RN50x64 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
# laion400m_32k=_pcfg(
# url="",
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# laion400m_64k=_pcfg(
# url="",
# mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)
_robertaViTB32 = dict(
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)
_xlmRobertaBaseViTB32 = dict(
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)
_xlmRobertaLargeFrozenViTH14 = dict(
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)
_convnext_base = dict(
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)
_convnext_base_w = dict(
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)
_convnext_base_w_320 = dict(
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)
_convnext_large_d = dict(
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)
_convnext_large_d_320 = dict(
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)
_convnext_xxlarge = dict(
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)
_coca_VITB32 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)
_coca_VITL14 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
_PRETRAINED = {
"RN50": _RN50,
"RN50-quickgelu": _RN50_quickgelu,
"RN101": _RN101,
"RN101-quickgelu": _RN101_quickgelu,
"RN50x4": _RN50x4,
"RN50x16": _RN50x16,
"RN50x64": _RN50x64,
"ViT-B-32": _VITB32,
"ViT-B-32-quickgelu": _VITB32_quickgelu,
"ViT-B-16": _VITB16,
"ViT-B-16-plus-240": _VITB16_PLUS_240,
"ViT-L-14": _VITL14,
"ViT-L-14-336": _VITL14_336,
"ViT-H-14": _VITH14,
"ViT-g-14": _VITg14,
"ViT-bigG-14": _VITbigG14,
"roberta-ViT-B-32": _robertaViTB32,
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
"convnext_base": _convnext_base,
"convnext_base_w": _convnext_base_w,
"convnext_base_w_320": _convnext_base_w_320,
"convnext_large_d": _convnext_large_d,
"convnext_large_d_320": _convnext_large_d_320,
"convnext_xxlarge": _convnext_xxlarge,
"coca_ViT-B-32": _coca_VITB32,
"coca_ViT-L-14": _coca_VITL14,
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ''
if not cfg:
return target
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
================================================
FILE: open_clip/push_to_hf_hub.py
================================================
import argparse
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple
import torch
try:
from huggingface_hub import (
create_repo,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
repo_type_and_id_from_hf_id,
upload_folder,
)
from huggingface_hub.utils import EntryNotFoundError
_has_hf_hub = True
except ImportError:
_has_hf_hub = False
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict]
):
preprocess_cfg = {
'mean': model.visual.image_mean,
'std': model.visual.image_std,
}
hf_config = {
'model_cfg': model_config,
'preprocess_cfg': preprocess_cfg,
}
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def save_for_hf(
model,
tokenizer: HFTokenizer,
model_config: dict,
save_directory: str,
weights_filename='open_clip_pytorch_model.bin',
config_filename='open_clip_config.json',
):
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / weights_filename
torch.save(model.state_dict(), weights_path)
tokenizer.save_pretrained(save_directory)
config_path = save_directory / config_filename
save_config_for_hf(model, config_path, model_config=model_config)
def push_to_hf_hub(
model,
tokenizer,
model_config: Optional[dict],
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
):
if not isinstance(tokenizer, HFTokenizer):
# default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(
model,
tokenizer=tokenizer,
model_config=model_config,
save_directory=tmpdir,
)
# Add readme if it does not exist
if not has_readme:
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = generate_readme(model_card, model_name)
readme_path.write_text(readme_text)
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)
def push_pretrained_to_hf_hub(
model_name,
pretrained: str,
repo_id: str,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_card: Optional[dict] = None,
):
model, preprocess_eval = create_model_from_pretrained(
model_name,
pretrained=pretrained,
image_mean=image_mean,
image_std=image_std,
)
model_config = get_model_config(model_name)
assert model_config
tokenizer = get_tokenizer(model_name)
push_to_hf_hub(
model=model,
tokenizer=tokenizer,
model_config=model_config,
repo_id=repo_id,
commit_message=commit_message,
token=token,
revision=revision,
private=private,
create_pr=create_pr,
model_card=model_card,
)
def generate_readme(model_card: dict, model_name: str):
readme_text = "---\n"
readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
readme_text += "library_tag: open_clip\n"
readme_text += f"license: {model_card.get('license', 'mit')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
parser.add_argument(
"--model", type=str, help="Name of the model to use.",
)
parser.add_argument(
"--pretrained", type=str,
help="Use a pretrained CLIP model weights with the specified tag or file path.",
)
parser.add_argument(
"--repo-id", type=str,
help="Destination HF Hub repo-id ie 'organization/model_id'.",
)
parser.add_argument(
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override default image mean value of dataset')
parser.add_argument(
'--image-std', type=float, nargs='+', default=None, metavar='STD',
help='Override default image std deviation of of dataset')
args = parser.parse_args()
print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
# FIXME add support to pass model_card json / template from file via cmd line
push_pretrained_to_hf_hub(
args.model,
args.pretrained,
args.repo_id,
image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
image_std=args.image_std,
)
print(f'{args.model} saved.')
================================================
FILE: open_clip/timm_model.py
================================================
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
""" timm model adapter
# FIXME this adapter is a work in progress, may change in ways that break weight compat
"""
def __init__(
self,
model_name,
embed_dim,
image_size=224,
pool='avg',
proj='linear',
proj_bias=False,
drop=0.,
drop_path=None,
pretrained=False,
):
super().__init__()
if timm is None:
raise RuntimeError("Please `pip install timm` to use timm models.")
self.image_size = to_2tuple(image_size)
timm_kwargs = {}
if drop_path is not None:
timm_kwargs['drop_path_rate'] = drop_path
self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if pool in ('abs_attn', 'rot_attn'):
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
else:
assert proj, 'projection layer needed if non-attention pooling is used.'
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception as e:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward(self, x):
x = self.trunk(x)
x = self.head(x)
return x
================================================
FILE: open_clip/tokenizer.py
================================================
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List
import ftfy
import regex as re
import torch
# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['', '']
else:
special_tokens = ['', ''] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '',)
pairs = get_pairs(word)
if not pairs:
return token+''
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
return text
_tokenizer = SimpleTokenizer()
def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder[""]
eot_token = _tokenizer.encoder[""]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
class HFTokenizer:
"""HuggingFace tokenizer wrapper"""
def __init__(self, tokenizer_name: str):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
return input_ids
================================================
FILE: open_clip/transform.py
================================================
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
@dataclass
class AugmentationCfg:
scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
interpolation: Optional[str] = None
re_prob: Optional[float] = None
re_count: Optional[int] = None
use_timm: bool = False
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == 'min' else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
return img
def _convert_to_rgb(image):
return image.convert('RGB')
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
if isinstance(aug_cfg, dict):
aug_cfg = AugmentationCfg(**aug_cfg)
else:
aug_cfg = aug_cfg or AugmentationCfg()
normalize = Normalize(mean=mean, std=std)
if is_train:
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
use_timm = aug_cfg_dict.pop('use_timm', False)
if use_timm:
from timm.data import create_transform # timm can still be optional
if isinstance(image_size, (tuple, list)):
assert len(image_size) >= 2
input_size = (3,) + image_size[-2:]
else:
input_size = (3, image_size, image_size)
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
aug_cfg_dict.setdefault('interpolation', 'random')
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
train_transform = create_transform(
input_size=input_size,
is_training=True,
hflip=0.,
mean=mean,
std=std,
re_mode='pixel',
**aug_cfg_dict,
)
else:
train_transform = Compose([
RandomResizedCrop(
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
ToTensor(),
normalize,
])
if aug_cfg_dict:
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
return train_transform
else:
if resize_longest_max:
transforms = [
ResizeMaxSize(image_size, fill=fill_color)
]
else:
transforms = [
Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
CenterCrop((image_size, image_size)),
]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)
================================================
FILE: open_clip/transformer.py
================================================
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from .utils import to_2tuple
import numpy as np
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
logit_scale_max=math.log(1. / 0.01),
attn_drop=0.,
proj_drop=0.
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
x = self.out_proj(x)
x = self.out_drop(x)
return x
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
return out.permute(1, 0, 2) # LND -> NLD
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
is_cross_attention: bool = False,
idx: int = 12,
):
super().__init__()
self.idx = idx
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(
q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
)
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
x = q_x + self.ls_1(tmp)
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x, attn
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = Attention(
d_model, n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer,
idx=idx)
for idx in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],
attn_mask: Optional[torch.Tensor] = None):
idx = 0
out_attn = []
# out_tokens = x
out_tokens = []
for r in self.resblocks:
idx += 1
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
if idx == 12:
x, attn = r(x, attn_mask=attn_mask)
out_attn.append(attn)
else:
x, attn_tmp = r(x, attn_mask=attn_mask)
if idx in out_layers:
out_tokens.append(x)
# out_tokens = x
return x, out_attn, out_tokens
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
global_average_pool: bool = False,
attentional_pool: bool = False,
n_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
input_patchnorm: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False
):
super().__init__()
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
self.input_patchnorm = input_patchnorm
if input_patchnorm:
patch_input_dim = patch_height * patch_width * 3
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
self.conv1 = nn.Linear(patch_input_dim, width)
else:
self.patchnorm_pre_ln = nn.Identity()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,
bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.global_average_pool = global_average_pool
if attentional_pool:
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
else:
self.attn_pool = None
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.init_parameters()
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def init_parameters(self):
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
# TODO experiment if default PyTorch init, below, or alternate init is best.
# nn.init.normal_(self.class_embedding, std=self.scale)
# nn.init.normal_(self.positional_embedding, std=self.scale)
#
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# attn_std = self.transformer.width ** -0.5
# fc_std = (2 * self.transformer.width) ** -0.5
# for block in self.transformer.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#
# if self.text_projection is not None:
# nn.init.normal_(self.text_projection, std=self.scale)
pass
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_average_pool:
return x.mean(dim=1), x
else:
return x[:, 0], x[:, 1:]
def forward(self, x: torch.Tensor, out_layers: list):
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],
self.patch_size[1])
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
x = self.patchnorm_pre_ln(x)
x = self.conv1(x)
else:
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x, attn, patch_tokens = self.transformer(x, out_layers)
# attn = attn[0, 0, 1:].view(14, 14) # 49
B, C, L = attn[0].shape
H = int(np.sqrt(L-1))
out_attn = torch.zeros([H, H]).to('cuda')
for i in range(len(attn)):
out_attn += attn[i][0, 0, 1:].view(H, H)
x = x.permute(1, 0, 2) # LND -> NLD
patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD
# patch_tokens = patch_tokens.permute(1, 0, 2) # LND -> NLD
if self.attn_pool is not None:
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
else:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
# patch_pooled, patch_tokens = self._global_pool(patch_tokens)
# tokens = self.ln_post(tokens)
if self.proj is not None:
pooled = pooled @ self.proj
# patch_tokens = patch_tokens @ self.proj # 不知道能不能行
# tokens = tokens @ self.proj
if self.output_tokens:
return pooled, patch_tokens
return pooled, patch_tokens
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
ls_init_value: float = None,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
embed_cls: bool = False,
pad_id: int = 0,
output_tokens: bool = False,
):
super().__init__()
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def _repeat(self, t, N: int):
return t.reshape(1, 1, -1).repeat(N, 1, 1)
def forward(self, text):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
if self.cls_emb is not None:
pooled, tokens = x[:, -1], x[:, :-1]
pooled = self.ln_final(pooled)
else:
x = self.ln_final(x)
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
if self.text_projection is not None:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
class MultimodalTransformer(Transformer):
def __init__(
self,
width: int,
layers: int,
heads: int,
context_length: int = 77,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_dim: int = 512,
):
super().__init__(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.context_length = context_length
self.cross_attn = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
is_cross_attention=True,
)
for _ in range(layers)
])
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
def init_parameters(self):
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
for block in self.transformer.cross_attn:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward(self, image_embs, text_embs):
text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
seq_len = text_embs.shape[0]
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
x = text_embs.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
if self.text_projection is not None:
x = x @ self.text_projection
return x
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
================================================
FILE: open_clip/utils.py
================================================
from itertools import repeat
import collections.abc
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
================================================
FILE: open_clip/version.py
================================================
__version__ = '2.16.0'
================================================
FILE: prompt_ensemble.py
================================================
import os
from typing import Union, List
from pkg_resources import packaging
import torch
import numpy as np
def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):
prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
prompt_state = [prompt_normal, prompt_abnormal]
prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
text_prompts = {}
for obj in objs:
text_features = []
for i in range(len(prompt_state)):
prompted_state = [state.format(obj) for state in prompt_state[i]]
prompted_sentence = []
for s in prompted_state:
for template in prompt_templates:
prompted_sentence.append(template.format(s))
prompted_sentence = tokenizer(prompted_sentence).to(device)
class_embeddings = model.encode_text(prompted_sentence)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
text_features.append(class_embedding)
text_features = torch.stack(text_features, dim=1).to(device)
text_prompts[obj] = text_features
return text_prompts
================================================
FILE: requirements.txt
================================================
ftfy==6.1.1
horovod==0.28.1
huggingface_hub==0.13.4
numpy==1.21.6
opencv_python==4.6.0.66
pandas==1.3.5
Pillow==9.2.0
regex==2022.10.31
scikit_image==0.19.3
scikit_learn==1.0.2
setuptools==63.4.1
tabulate==0.9.0
timm==0.8.15.dev0
torch==1.12.1
torchvision==0.13.1
tqdm==4.64.1
transformers==4.15.0
================================================
FILE: test.py
================================================
import os
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
from PIL import Image
from skimage import measure
from tabulate import tabulate
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
import open_clip
from few_shot import memory
from model import LinearLayer
from dataset import VisaDataset, MVTecDataset
from prompt_ensemble import encode_text_with_prompt_ensemble
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def normalize(pred, max_value=None, min_value=None):
if max_value is None or min_value is None:
return (pred - pred.min()) / (pred.max() - pred.min())
else:
return (pred - min_value) / (max_value - min_value)
def apply_ad_scoremap(image, scoremap, alpha=0.5):
np_image = np.asarray(image, dtype=float)
scoremap = (scoremap * 255).astype(np.uint8)
scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
# ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
binary_amaps = np.zeros_like(amaps, dtype=bool)
min_th, max_th = amaps.min(), amaps.max()
delta = (max_th - min_th) / max_step
pros, fprs, ths = [], [], []
for th in np.arange(min_th, max_th, delta):
binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1
pro = []
for binary_amap, mask in zip(binary_amaps, masks):
for region in measure.regionprops(measure.label(mask)):
tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
pro.append(tp_pixels / region.area)
inverse_masks = 1 - masks
fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
fpr = fp_pixels / inverse_masks.sum()
pros.append(np.array(pro).mean())
fprs.append(fpr)
ths.append(th)
pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
idxes = fprs < expect_fpr
fprs = fprs[idxes]
fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
pro_auc = auc(fprs, pros[idxes])
return pro_auc
def test(args):
img_size = args.image_size
features_list = args.features_list
few_shot_features = args.few_shot_features
dataset_dir = args.data_path
save_path = args.save_path
dataset_name = args.dataset
if not os.path.exists(save_path):
os.makedirs(save_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
txt_path = os.path.join(save_path, 'log.txt')
# clip
model, _, preprocess = open_clip.create_model_and_transforms(args.model, img_size, pretrained=args.pretrained)
model.to(device)
tokenizer = open_clip.get_tokenizer(args.model)
# logger
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
root_logger.setLevel(logging.WARNING)
logger = logging.getLogger('test')
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
datefmt='%y-%m-%d %H:%M:%S')
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(txt_path, mode='w')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# record parameters
for arg in vars(args):
if args.mode == 'zero_shot' and (arg == 'k_shot' or arg == 'few_shot_features'):
continue
logger.info(f'{arg}: {getattr(args, arg)}')
# seg
with open(args.config_path, 'r') as f:
model_configs = json.load(f)
linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
len(features_list), args.model).to(device)
checkpoint = torch.load(args.checkpoint_path)
linearlayer.load_state_dict(checkpoint["trainable_linearlayer"])
# dataset
transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.CenterCrop(img_size),
transforms.ToTensor()
])
if dataset_name == 'mvtec':
test_data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
aug_rate=-1, mode='test')
else:
test_data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, mode='test')
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
obj_list = test_data.get_cls_names()
# few shot
if args.mode == 'few_shot':
mem_features = memory(args.model, model, obj_list, dataset_dir, save_path, preprocess, transform,
args.k_shot, few_shot_features, dataset_name, device)
# text prompt
with torch.cuda.amp.autocast(), torch.no_grad():
text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)
results = {}
results['cls_names'] = []
results['imgs_masks'] = []
results['anomaly_maps'] = []
results['gt_sp'] = []
results['pr_sp'] = []
for items in test_dataloader:
image = items['img'].to(device)
cls_name = items['cls_name']
results['cls_names'].append(cls_name[0])
gt_mask = items['img_mask']
gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
results['imgs_masks'].append(gt_mask) # px
results['gt_sp'].append(items['anomaly'].item())
with torch.no_grad(), torch.cuda.amp.autocast():
image_features, patch_tokens = model.encode_image(image, features_list)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features = []
for cls in cls_name:
text_features.append(text_prompts[cls])
text_features = torch.stack(text_features, dim=0)
# sample
text_probs = (100.0 * image_features @ text_features[0]).softmax(dim=-1)
results['pr_sp'].append(text_probs[0][1].cpu().item())
# pixel
patch_tokens = linearlayer(patch_tokens)
anomaly_maps = []
for layer in range(len(patch_tokens)):
patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)
anomaly_map = (100.0 * patch_tokens[layer] @ text_features)
B, L, C = anomaly_map.shape
H = int(np.sqrt(L))
anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
size=img_size, mode='bilinear', align_corners=True)
anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]
anomaly_maps.append(anomaly_map.cpu().numpy())
anomaly_map = np.sum(anomaly_maps, axis=0)
# few shot
if args.mode == 'few_shot':
image_features, patch_tokens = model.encode_image(image, few_shot_features)
anomaly_maps_few_shot = []
for idx, p in enumerate(patch_tokens):
if 'ViT' in args.model:
p = p[0, 1:, :]
else:
p = p[0].view(p.shape[1], -1).permute(1, 0).contiguous()
cos = pairwise.cosine_similarity(mem_features[cls_name[0]][idx].cpu(), p.cpu())
height = int(np.sqrt(cos.shape[1]))
anomaly_map_few_shot = np.min((1 - cos), 0).reshape(1, 1, height, height)
anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
size=img_size, mode='bilinear', align_corners=True)
anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)
anomaly_map = anomaly_map + anomaly_map_few_shot
results['anomaly_maps'].append(anomaly_map)
# visualization
path = items['img_path']
cls = path[0].split('/')[-2]
filename = path[0].split('/')[-1]
vis = cv2.cvtColor(cv2.resize(cv2.imread(path[0]), (img_size, img_size)), cv2.COLOR_BGR2RGB) # RGB
mask = normalize(anomaly_map[0])
vis = apply_ad_scoremap(vis, mask)
vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR) # BGR
save_vis = os.path.join(save_path, 'imgs', cls_name[0], cls)
if not os.path.exists(save_vis):
os.makedirs(save_vis)
cv2.imwrite(os.path.join(save_vis, filename), vis)
# metrics
table_ls = []
auroc_sp_ls = []
auroc_px_ls = []
f1_sp_ls = []
f1_px_ls = []
aupro_ls = []
ap_sp_ls = []
ap_px_ls = []
for obj in obj_list:
table = []
gt_px = []
pr_px = []
gt_sp = []
pr_sp = []
pr_sp_tmp = []
table.append(obj)
for idxes in range(len(results['cls_names'])):
if results['cls_names'][idxes] == obj:
gt_px.append(results['imgs_masks'][idxes].squeeze(1).numpy())
pr_px.append(results['anomaly_maps'][idxes])
pr_sp_tmp.append(np.max(results['anomaly_maps'][idxes]))
gt_sp.append(results['gt_sp'][idxes])
pr_sp.append(results['pr_sp'][idxes])
gt_px = np.array(gt_px)
gt_sp = np.array(gt_sp)
pr_px = np.array(pr_px)
pr_sp = np.array(pr_sp)
if args.mode == 'few_shot':
pr_sp_tmp = np.array(pr_sp_tmp)
pr_sp_tmp = (pr_sp_tmp - pr_sp_tmp.min()) / (pr_sp_tmp.max() - pr_sp_tmp.min())
pr_sp = 0.5 * (pr_sp + pr_sp_tmp)
auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())
auroc_sp = roc_auc_score(gt_sp, pr_sp)
ap_sp = average_precision_score(gt_sp, pr_sp)
ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
# f1_sp
precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp)
f1_scores = (2 * precisions * recalls) / (precisions + recalls)
f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])
# f1_px
precisions, recalls, thresholds = precision_recall_curve(gt_px.ravel(), pr_px.ravel())
f1_scores = (2 * precisions * recalls) / (precisions + recalls)
f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
# aupro
if len(gt_px.shape) == 4:
gt_px = gt_px.squeeze(1)
if len(pr_px.shape) == 4:
pr_px = pr_px.squeeze(1)
aupro = cal_pro_score(gt_px, pr_px)
table.append(str(np.round(auroc_px * 100, decimals=1)))
table.append(str(np.round(f1_px * 100, decimals=1)))
table.append(str(np.round(ap_px * 100, decimals=1)))
table.append(str(np.round(aupro * 100, decimals=1)))
table.append(str(np.round(auroc_sp * 100, decimals=1)))
table.append(str(np.round(f1_sp * 100, decimals=1)))
table.append(str(np.round(ap_sp * 100, decimals=1)))
table_ls.append(table)
auroc_sp_ls.append(auroc_sp)
auroc_px_ls.append(auroc_px)
f1_sp_ls.append(f1_sp)
f1_px_ls.append(f1_px)
aupro_ls.append(aupro)
ap_sp_ls.append(ap_sp)
ap_px_ls.append(ap_px)
# logger
table_ls.append(['mean', str(np.round(np.mean(auroc_px_ls) * 100, decimals=1)),
str(np.round(np.mean(f1_px_ls) * 100, decimals=1)), str(np.round(np.mean(ap_px_ls) * 100, decimals=1)),
str(np.round(np.mean(aupro_ls) * 100, decimals=1)), str(np.round(np.mean(auroc_sp_ls) * 100, decimals=1)),
str(np.round(np.mean(f1_sp_ls) * 100, decimals=1)), str(np.round(np.mean(ap_sp_ls) * 100, decimals=1))])
results = tabulate(table_ls, headers=['objects', 'auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp',
'f1_sp', 'ap_sp'], tablefmt="pipe")
logger.info("\n%s", results)
if __name__ == '__main__':
parser = argparse.ArgumentParser("VAND Challenge", add_help=True)
# paths
parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset")
parser.add_argument("--save_path", type=str, default='./results/tiaoshi', help='path to save results')
parser.add_argument("--checkpoint_path", type=str, default='./exps/vit_huge_14/model_epoch12.pth', help='path to save results')
parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16.json', help="model configs")
# model
parser.add_argument("--dataset", type=str, default='mvtec', help="test dataset")
parser.add_argument("--model", type=str, default="ViT-B-16", help="model used")
parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used")
parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9], help="features used")
parser.add_argument("--few_shot_features", type=int, nargs="+", default=[3, 6, 9], help="features used for few shot")
parser.add_argument("--image_size", type=int, default=224, help="image size")
parser.add_argument("--mode", type=str, default="zero_shot", help="zero shot or few shot")
# few shot
parser.add_argument("--k_shot", type=int, default=10, help="e.g., 10-shot, 5-shot, 1-shot")
parser.add_argument("--seed", type=int, default=10, help="random seed")
args = parser.parse_args()
setup_seed(args.seed)
test(args)
================================================
FILE: test_few_shot.sh
================================================
### test on the VisA dataset
python test.py --mode few_shot --dataset visa \
--data_path ./data/visa --save_path ./results/visa/few_shot/4shot/seed42 \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \
--pretrained openai --image_size 518 --k_shot 4 --seed 42
### test on the MVTec AD dataset
python test.py --mode few_shot --dataset mvtec \
--data_path ./data/mvtec --save_path ./results/mvtec/few_shot/4shot/seed42 \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \
--pretrained openai --image_size 518 --k_shot 4 --seed 42
================================================
FILE: test_zero_shot.sh
================================================
### test on the VisA dataset
python test.py --mode zero_shot --dataset visa \
--data_path ./data/visa --save_path ./results/visa/zero_shot \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518
### test on the MVTec AD dataset
python test.py --mode zero_shot --dataset mvtec \
--data_path ./data/mvtec --save_path ./results/mvtec/zero_shot \
--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518
================================================
FILE: train.py
================================================
import torch
import torch.nn as nn
import numpy as np
import random
import os
import json
import argparse
from torch.utils.data import DataLoader
from datetime import datetime
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import logging
import open_clip
from dataset import VisaDataset, MVTecDataset
from model import LinearLayer
from loss import FocalLoss, BinaryDiceLoss
from prompt_ensemble import encode_text_with_prompt_ensemble
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train(args):
# configs
epochs = args.epoch
learning_rate = args.learning_rate
batch_size = args.batch_size
image_size = args.image_size
device = 'cuda' if torch.cuda.is_available() else 'cpu'
save_path = args.save_path
if not os.path.exists(save_path):
os.makedirs(save_path)
txt_path = os.path.join(save_path, 'log.txt') # log
# model configs
features_list = args.features_list
with open(args.config_path, 'r') as f:
model_configs = json.load(f)
# clip model
model, _, preprocess = open_clip.create_model_and_transforms(args.model, image_size, pretrained=args.pretrained)
model.to(device)
tokenizer = open_clip.get_tokenizer(args.model)
# logger
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
root_logger.setLevel(logging.WARNING)
logger = logging.getLogger('train')
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
datefmt='%y-%m-%d %H:%M:%S')
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(txt_path, mode='w')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# record parameters
for arg in vars(args):
logger.info(f'{arg}: {getattr(args, arg)}')
# transforms
transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
# datasets
if args.dataset == 'mvtec':
train_data = MVTecDataset(root=args.train_data_path, transform=preprocess, target_transform=transform,
aug_rate=args.aug_rate)
else:
train_data = VisaDataset(root=args.train_data_path, transform=preprocess, target_transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# linear layer
trainable_layer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
len(args.features_list), args.model).to(device)
optimizer = torch.optim.Adam(list(trainable_layer.parameters()), lr=learning_rate, betas=(0.5, 0.999))
# losses
loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()
# text prompt
with torch.cuda.amp.autocast(), torch.no_grad():
obj_list = train_data.get_cls_names()
text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)
for epoch in range(epochs):
loss_list = []
idx = 0
for items in train_dataloader:
idx += 1
image = items['img'].to(device)
cls_name = items['cls_name']
with torch.cuda.amp.autocast():
with torch.no_grad():
image_features, patch_tokens = model.encode_image(image, features_list)
text_features = []
for cls in cls_name:
text_features.append(text_prompts[cls])
text_features = torch.stack(text_features, dim=0)
# pixel level
patch_tokens = trainable_layer(patch_tokens)
anomaly_maps = []
for layer in range(len(patch_tokens)):
patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)
anomaly_map = (100.0 * patch_tokens[layer] @ text_features)
B, L, C = anomaly_map.shape
H = int(np.sqrt(L))
anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
size=image_size, mode='bilinear', align_corners=True)
anomaly_map = torch.softmax(anomaly_map, dim=1)
anomaly_maps.append(anomaly_map)
# losses
gt = items['img_mask'].squeeze().to(device)
gt[gt > 0.5], gt[gt <= 0.5] = 1, 0
loss = 0
for num in range(len(anomaly_maps)):
loss += loss_focal(anomaly_maps[num], gt)
loss += loss_dice(anomaly_maps[num][:, 1, :, :], gt)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.item())
# logs
if (epoch + 1) % args.print_freq == 0:
logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, np.mean(loss_list)))
# save model
if (epoch + 1) % args.save_freq == 0:
ckp_path = os.path.join(save_path, 'epoch_' + str(epoch + 1) + '.pth')
torch.save({'trainable_linearlayer': trainable_layer.state_dict()}, ckp_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser("VAND Challenge", add_help=True)
# path
parser.add_argument("--train_data_path", type=str, default="./data/visa", help="train dataset path")
parser.add_argument("--save_path", type=str, default='./exps/vit_large_14_518', help='path to save results')
parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16.json', help="model configs")
# model
parser.add_argument("--dataset", type=str, default='mvtec', help="train dataset name")
parser.add_argument("--model", type=str, default="ViT-B-16", help="model used")
parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used")
parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9], help="features used")
# hyper-parameter
parser.add_argument("--epoch", type=int, default=200, help="epochs")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--image_size", type=int, default=224, help="image size")
parser.add_argument("--aug_rate", type=float, default=0.2, help="image size")
parser.add_argument("--print_freq", type=int, default=30, help="print frequency")
parser.add_argument("--save_freq", type=int, default=3, help="save frequency")
args = parser.parse_args()
setup_seed(111)
train(args)
================================================
FILE: train.sh
================================================
### train on the MVTec AD dataset
python train.py --dataset mvtec --train_data_path ./data/mvtec \
--save_path ./exps/visa/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \
--features_list 6 12 18 24 --pretrained openai --image_size 518 --batch_size 8 --aug_rate 0.2 --print_freq 1 \
--epoch 3 --save_freq 1
### train on the VisA dataset
python train.py --dataset visa --train_data_path ./data/visa \
--save_path ./exps/mvtec/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \
--features_list 6 12 18 24 --pretrained openai --image_size 518 --batch_size 8 --print_freq 1 \
--epoch 15 --save_freq 1