[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Xuhai Chen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "[Workshop Link](https://sites.google.com/view/vand-cvpr23/home) | [Challenge Link](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0) | [Report Paper](https://arxiv.org/abs/2305.17382)\r\n---\r\n[Xuhai Chen](https://scholar.google.com.hk/citations?user=LU4etJ0AAAAJ&hl=zh-CN&authuser=1) · [Yue Han](https://scholar.google.com/citations?hl=en&user=08E500gAAAAJ&view_op=list_works&gmla=AHoSzlVzTXnclaPp9h1g8xAZQBrsxdFXvhunMA3AmRm_GSLnZA1956xavx6hmPaCFCysonsXeTQyhB_cokdUFacUc5HBunMPW-uOtLZLTTufiZiHB6hAVgr9l7cJ_UHKeQ) · [Jiangning Zhang](https://zhangzjn.github.io/)\r\n\r\nThis repository contains the official PyTorch implementation of [Zero-/Few-shot Anomaly Classification and Segmentation Method](https://arxiv.org/abs/2305.17382) used in the [CVPR 2023 VAND Challenge](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0), which can be viewd as an improved version of [WinCLIP](https://arxiv.org/abs/2303.14814). We achieve **Winner** in the Zero-shot Track and **Honorable Mentions** in the Few-shot Track.\r\n\r\n<img src=\"illustration/main.png\" alt=\"Model Structure\" style=\"max-width: 50px; height: auto;\">\r\n\r\n**Results on the Challenge official test set**\r\n\r\n<img src=\"illustration/results.png\" alt=\"Model Structure\" style=\"max-width: 50px; height: auto;\">\r\n\r\n## Installation\r\n\r\n- Prepare experimental environments\r\n\r\n  ```shell\r\n  pip install -r requirements.txt\r\n  ```\r\n  \r\n## Dataset Preparation \r\n### MVTec AD\r\n- Download and extract [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) into `data/mvtec`\r\n- run`python data/mvtec.py` to obtain `data/mvtec/meta.json`\r\n```\r\ndata\r\n├── mvtec\r\n    ├── meta.json\r\n    ├── bottle\r\n        ├── train\r\n            ├── good\r\n                ├── 000.png\r\n        ├── test\r\n            ├── good\r\n                ├── 000.png\r\n            ├── anomaly1\r\n                ├── 000.png\r\n        ├── ground_truth\r\n            ├── anomaly1\r\n                ├── 000.png\r\n```\r\n\r\n### VisA\r\n- Download and extract [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar) into `data/visa`\r\n- run`python data/visa.py` to obtain `data/visa/meta.json`\r\n```\r\ndata\r\n├── visa\r\n    ├── meta.json\r\n    ├── candle\r\n        ├── Data\r\n            ├── Images\r\n                ├── Anomaly\r\n                    ├── 000.JPG\r\n                ├── Normal\r\n                    ├── 0000.JPG\r\n            ├── Masks\r\n                ├── Anomaly\r\n                    ├── 000.png\r\n```\r\n\r\n## Train\r\nSet parameters in `train.sh`.\r\n- `train_data_path`: the path to the training dataset\r\n- `dataset`: name of the training dataset, optional: mvtec, visa\r\n- `model`: the CLIP model\r\n- `pretrained`: the pretrained weights\r\n- `features_list`: features to be mapped into the joint embedding space\r\n- `image_size`: the size of the images inputted into the CLIP model\r\n- `aug_rate`: the probability of stitching images (only for mvtec)\r\n\r\nThen run the following command\r\n  ```shell\r\n  sh train.sh\r\n  ```\r\n\r\n## Test\r\n### Pretrained Models\r\nWe provide our pre-trained models in `exps/pretrained`, where `mvtec_pretrained.pth` represents the model trained on the MVTec AD dataset and `visa_pretrained.pth` represents the model trained on the VisA dataset.\r\n\r\nSet parameters in `test_zero_shot.sh`.\r\n- `data_path`: the path to the test dataset\r\n- `dataset`: name of the test dataset, optional: mvtec, visa\r\n- `checkpoint_path`: the path to the test model\r\n\r\nThen, run the following command to test them in the zero-shot setting:\r\n  ```shell\r\n  sh test_zero_shot.sh\r\n  ```\r\n  \r\nSet parameters in `test_few_shot.sh`.\r\n- `data_path`: the path to the test dataset\r\n- `dataset`: name of the test dataset, optional: mvtec, visa\r\n- `checkpoint_path`: the path to the test model\r\n- `k_shot`: different number of reference images\r\n\r\nThen, run the following command to test them in the few-shot setting:\r\n  ```shell\r\n  sh test_few_shot.sh\r\n  ```\r\n\r\n### Zero-shot Setting\r\nSet parameters in `test_zero_shot.sh`.\r\n- `data_path`: the path to the test dataset\r\n- `dataset`: name of the test dataset, optional: mvtec, visa\r\n- `checkpoint_path`: the path to the test model\r\n- `model`: the CLIP model\r\n- `pretrained`: the pretrained weights\r\n- `features_list`: features to be mapped into the joint embedding space\r\n- `image_size`: the size of the images inputted into the CLIP model\r\n- `mode`: zero shot or few shot\r\n\r\nThen run the following command\r\n  ```shell\r\n  sh test_zero_shot.sh\r\n  ```\r\n\r\n### Few-shot Setting\r\nSet parameters in `test_few_shot.sh`.\r\n- `data_path`: the path to the test dataset\r\n- `dataset`: name of the test dataset, optional: mvtec, visa\r\n- `checkpoint_path`: the path to the test model\r\n- `model`: the CLIP model\r\n- `pretrained`: the pretrained weights\r\n- `features_list`: features to be mapped into the joint embedding space\r\n- `few_shot_features`: features stored in the memory banks\r\n- `image_size`: the size of the images inputted into the CLIP model\r\n- `mode`: zero shot or few shot\r\n- `k_shot`: different number of reference images\r\n- `seed`: the random seed\r\n\r\nThen run the following command\r\n  ```shell\r\n  sh test_few_shot.sh\r\n  ```\r\n\r\n## Citation\r\nIf our work is helpful for your research, please consider citing:\r\n\r\n```\r\n@article{chen2023zero,\r\n  title={A Zero-/Few-Shot Anomaly Classification and Segmentation Method for CVPR 2023 VAND Workshop Challenge Tracks 1\\&2: 1st Place on Zero-shot AD and 4th Place on Few-shot AD},\r\n  author={Chen, Xuhai and Han, Yue and Zhang, Jiangning},\r\n  journal={arXiv preprint arXiv:2305.17382},\r\n  year={2023}\r\n}\r\n```\r\n\r\n## Acknowledgements\r\nWe thank [WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation](https://arxiv.org/abs/2303.14814) for providing assistance for our research.\r\n\r\n"
  },
  {
    "path": "data/mvtec.py",
    "content": "import os\r\nimport json\r\n\r\n\r\nclass MVTecSolver(object):\r\n    CLSNAMES = [\r\n        'bottle', 'cable', 'capsule', 'carpet', 'grid',\r\n        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',\r\n        'tile', 'toothbrush', 'transistor', 'wood', 'zipper',\r\n    ]\r\n\r\n    def __init__(self, root='data/mvtec'):\r\n        self.root = root\r\n        self.meta_path = f'{root}/meta.json'\r\n\r\n    def run(self):\r\n        info = dict(train={}, test={})\r\n        for cls_name in self.CLSNAMES:\r\n            cls_dir = f'{self.root}/{cls_name}'\r\n            for phase in ['train', 'test']:\r\n                cls_info = []\r\n                species = os.listdir(f'{cls_dir}/{phase}')\r\n                for specie in species:\r\n                    is_abnormal = True if specie not in ['good'] else False\r\n                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')\r\n                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None\r\n                    img_names.sort()\r\n                    mask_names.sort() if mask_names is not None else None\r\n                    for idx, img_name in enumerate(img_names):\r\n                        info_img = dict(\r\n                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',\r\n                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',\r\n                            cls_name=cls_name,\r\n                            specie_name=specie,\r\n                            anomaly=1 if is_abnormal else 0,\r\n                        )\r\n                        cls_info.append(info_img)\r\n                info[phase][cls_name] = cls_info\r\n        with open(self.meta_path, 'w') as f:\r\n            f.write(json.dumps(info, indent=4) + \"\\n\")\r\n\r\nif __name__ == '__main__':\r\n    runner = MVTecSolver(root='data/mvtec')\r\n    runner.run()\r\n"
  },
  {
    "path": "data/visa.py",
    "content": "import os\r\nimport json\r\nimport pandas as pd\r\n\r\n\r\nclass VisASolver(object):\r\n    CLSNAMES = [\r\n        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',\r\n        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',\r\n        'pcb4', 'pipe_fryum',\r\n    ]\r\n\r\n    def __init__(self, root='data/visa'):\r\n        self.root = root\r\n        self.meta_path = f'{root}/meta.json'\r\n        self.phases = ['train', 'test']\r\n        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)\r\n\r\n    def run(self):\r\n        columns = self.csv_data.columns  # [object, split, label, image, mask]\r\n        info = {phase: {} for phase in self.phases}\r\n        for cls_name in self.CLSNAMES:\r\n            cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name]\r\n            for phase in self.phases:\r\n                cls_info = []\r\n                cls_data_phase = cls_data[cls_data[columns[1]] == phase]\r\n                cls_data_phase.index = list(range(len(cls_data_phase)))\r\n                for idx in range(cls_data_phase.shape[0]):\r\n                    data = cls_data_phase.loc[idx]\r\n                    is_abnormal = True if data[2] == 'anomaly' else False\r\n                    info_img = dict(\r\n                        img_path=data[3],\r\n                        mask_path=data[4] if is_abnormal else '',\r\n                        cls_name=cls_name,\r\n                        specie_name='',\r\n                        anomaly=1 if is_abnormal else 0,\r\n                    )\r\n                    cls_info.append(info_img)\r\n                info[phase][cls_name] = cls_info\r\n        with open(self.meta_path, 'w') as f:\r\n            f.write(json.dumps(info, indent=4) + \"\\n\")\r\n\r\n\r\nif __name__ == '__main__':\r\n    runner = VisASolver(root='data/visa')\r\n    runner.run()\r\n"
  },
  {
    "path": "dataset.py",
    "content": "import torch.utils.data as data\r\nimport json\r\nimport random\r\nfrom PIL import Image\r\nimport numpy as np\r\nimport torch\r\nimport os\r\n\r\n\r\nclass VisaDataset(data.Dataset):\r\n\tdef __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None):\r\n\t\tself.root = root\r\n\t\tself.transform = transform\r\n\t\tself.target_transform = target_transform\r\n\r\n\t\tself.data_all = []\r\n\t\tmeta_info = json.load(open(f'{self.root}/meta.json', 'r'))\r\n\t\tname = self.root.split('/')[-1]\r\n\t\tmeta_info = meta_info[mode]\r\n\r\n\t\tif mode == 'train':\r\n\t\t\tself.cls_names = [obj_name]\r\n\t\t\tsave_dir = os.path.join(save_dir, 'k_shot.txt')\r\n\t\telse:\r\n\t\t\tself.cls_names = list(meta_info.keys())\r\n\t\tfor cls_name in self.cls_names:\r\n\t\t\tif mode == 'train':\r\n\t\t\t\tdata_tmp = meta_info[cls_name]\r\n\t\t\t\tindices = torch.randint(0, len(data_tmp), (k_shot,))\r\n\t\t\t\tfor i in range(len(indices)):\r\n\t\t\t\t\tself.data_all.append(data_tmp[indices[i]])\r\n\t\t\t\t\twith open(save_dir, \"a\") as f:\r\n\t\t\t\t\t\tf.write(data_tmp[indices[i]]['img_path'] + '\\n')\r\n\t\t\telse:\r\n\t\t\t\tself.data_all.extend(meta_info[cls_name])\r\n\t\tself.length = len(self.data_all)\r\n\r\n\tdef __len__(self):\r\n\t\treturn self.length\r\n\r\n\tdef get_cls_names(self):\r\n\t\treturn self.cls_names\r\n\r\n\tdef __getitem__(self, index):\r\n\t\tdata = self.data_all[index]\r\n\t\timg_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \\\r\n\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t  data['specie_name'], data['anomaly']\r\n\t\timg = Image.open(os.path.join(self.root, img_path))\r\n\t\tif anomaly == 0:\r\n\t\t\timg_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')\r\n\t\telse:\r\n\t\t\timg_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0\r\n\t\t\timg_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')\r\n\t\timg = self.transform(img) if self.transform is not None else img\r\n\t\timg_mask = self.target_transform(\r\n\t\t\timg_mask) if self.target_transform is not None and img_mask is not None else img_mask\r\n\t\timg_mask = [] if img_mask is None else img_mask\r\n\r\n\t\treturn {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,\r\n\t\t\t\t'img_path': os.path.join(self.root, img_path)}\r\n\r\n\r\nclass MVTecDataset(data.Dataset):\r\n\tdef __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None, obj_name=None):\r\n\t\tself.root = root\r\n\t\tself.transform = transform\r\n\t\tself.target_transform = target_transform\r\n\t\tself.aug_rate = aug_rate\r\n\r\n\t\tself.data_all = []\r\n\t\tmeta_info = json.load(open(f'{self.root}/meta.json', 'r'))\r\n\t\tname = self.root.split('/')[-1]\r\n\t\tmeta_info = meta_info[mode]\r\n\r\n\t\tif mode == 'train':\r\n\t\t\tself.cls_names = [obj_name]\r\n\t\t\tsave_dir = os.path.join(save_dir, 'k_shot.txt')\r\n\t\telse:\r\n\t\t\tself.cls_names = list(meta_info.keys())\r\n\t\tfor cls_name in self.cls_names:\r\n\t\t\tif mode == 'train':\r\n\t\t\t\tdata_tmp = meta_info[cls_name]\r\n\t\t\t\tindices = torch.randint(0, len(data_tmp), (k_shot,))\r\n\t\t\t\tfor i in range(len(indices)):\r\n\t\t\t\t\tself.data_all.append(data_tmp[indices[i]])\r\n\t\t\t\t\twith open(save_dir, \"a\") as f:\r\n\t\t\t\t\t\tf.write(data_tmp[indices[i]]['img_path'] + '\\n')\r\n\t\t\telse:\r\n\t\t\t\tself.data_all.extend(meta_info[cls_name])\r\n\t\tself.length = len(self.data_all)\r\n\r\n\tdef __len__(self):\r\n\t\treturn self.length\r\n\r\n\tdef get_cls_names(self):\r\n\t\treturn self.cls_names\r\n\r\n\tdef combine_img(self, cls_name):\r\n\t\timg_paths = os.path.join(self.root, cls_name, 'test')\r\n\t\timg_ls = []\r\n\t\tmask_ls = []\r\n\t\tfor i in range(4):\r\n\t\t\tdefect = os.listdir(img_paths)\r\n\t\t\trandom_defect = random.choice(defect)\r\n\t\t\tfiles = os.listdir(os.path.join(img_paths, random_defect))\r\n\t\t\trandom_file = random.choice(files)\r\n\t\t\timg_path = os.path.join(img_paths, random_defect, random_file)\r\n\t\t\tmask_path = os.path.join(self.root, cls_name, 'ground_truth', random_defect, random_file[:3] + '_mask.png')\r\n\t\t\timg = Image.open(img_path)\r\n\t\t\timg_ls.append(img)\r\n\t\t\tif random_defect == 'good':\r\n\t\t\t\timg_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')\r\n\t\t\telse:\r\n\t\t\t\timg_mask = np.array(Image.open(mask_path).convert('L')) > 0\r\n\t\t\t\timg_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')\r\n\t\t\tmask_ls.append(img_mask)\r\n\t\t# image\r\n\t\timage_width, image_height = img_ls[0].size\r\n\t\tresult_image = Image.new(\"RGB\", (2 * image_width, 2 * image_height))\r\n\t\tfor i, img in enumerate(img_ls):\r\n\t\t\trow = i // 2\r\n\t\t\tcol = i % 2\r\n\t\t\tx = col * image_width\r\n\t\t\ty = row * image_height\r\n\t\t\tresult_image.paste(img, (x, y))\r\n\r\n\t\t# mask\r\n\t\tresult_mask = Image.new(\"L\", (2 * image_width, 2 * image_height))\r\n\t\tfor i, img in enumerate(mask_ls):\r\n\t\t\trow = i // 2\r\n\t\t\tcol = i % 2\r\n\t\t\tx = col * image_width\r\n\t\t\ty = row * image_height\r\n\t\t\tresult_mask.paste(img, (x, y))\r\n\r\n\t\treturn result_image, result_mask\r\n\r\n\tdef __getitem__(self, index):\r\n\t\tdata = self.data_all[index]\r\n\t\timg_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \\\r\n\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t  data['specie_name'], data['anomaly']\r\n\t\trandom_number = random.random()\r\n\t\tif random_number < self.aug_rate:\r\n\t\t\timg, img_mask = self.combine_img(cls_name)\r\n\t\telse:\r\n\t\t\timg = Image.open(os.path.join(self.root, img_path))\r\n\t\t\tif anomaly == 0:\r\n\t\t\t\timg_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')\r\n\t\t\telse:\r\n\t\t\t\timg_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0\r\n\t\t\t\timg_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')\r\n\t\t# transforms\r\n\t\timg = self.transform(img) if self.transform is not None else img\r\n\t\timg_mask = self.target_transform(\r\n\t\t\timg_mask) if self.target_transform is not None and img_mask is not None else img_mask\r\n\t\timg_mask = [] if img_mask is None else img_mask\r\n\t\treturn {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,\r\n\t\t\t\t'img_path': os.path.join(self.root, img_path)}\r\n"
  },
  {
    "path": "few_shot.py",
    "content": "import torch\r\nfrom dataset import VisaDataset, MVTecDataset\r\n\r\ndef memory(model_name, model, obj_list, dataset_dir, save_path, preprocess, transform, k_shot,\r\n           few_shot_features, dataset_name, device):\r\n    mem_features = {}\r\n    for obj in obj_list:\r\n        if dataset_name == 'mvtec':\r\n            data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,\r\n                                aug_rate=-1, mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)\r\n        else:\r\n            data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform,\r\n                               mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)\r\n        dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)\r\n        features = []\r\n        for items in dataloader:\r\n            image = items['img'].to(device)\r\n            with torch.no_grad():\r\n                image_features, patch_tokens = model.encode_image(image, few_shot_features)\r\n                if 'ViT' in model_name:\r\n                    patch_tokens = [p[0, 1:, :] for p in patch_tokens]\r\n                else:\r\n                    patch_tokens = [p[0].view(p.shape[1], -1).permute(1, 0).contiguous() for p in patch_tokens]\r\n                features.append(patch_tokens)\r\n        mem_features[obj] = [torch.cat(\r\n            [features[j][i] for j in range(len(features))], dim=0) for i in range(len(features[0]))]\r\n    return mem_features"
  },
  {
    "path": "loss.py",
    "content": "import numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom math import exp\r\n\r\nclass FocalLoss(nn.Module):\r\n    \"\"\"\r\n    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py\r\n    This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in\r\n    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'\r\n        Focal_Loss= -1*alpha*(1-pt)*log(pt)\r\n    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion\r\n    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more\r\n                    focus on hard misclassified example\r\n    :param smooth: (float,double) smooth value when cross entropy\r\n    :param balance_index: (int) balance class index, should be specific when alpha is float\r\n    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.\r\n    \"\"\"\r\n\r\n    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):\r\n        super(FocalLoss, self).__init__()\r\n        self.apply_nonlin = apply_nonlin\r\n        self.alpha = alpha\r\n        self.gamma = gamma\r\n        self.balance_index = balance_index\r\n        self.smooth = smooth\r\n        self.size_average = size_average\r\n\r\n        if self.smooth is not None:\r\n            if self.smooth < 0 or self.smooth > 1.0:\r\n                raise ValueError('smooth value should be in [0,1]')\r\n\r\n    def forward(self, logit, target):\r\n        if self.apply_nonlin is not None:\r\n            logit = self.apply_nonlin(logit)\r\n        num_class = logit.shape[1]\r\n\r\n        if logit.dim() > 2:\r\n            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)\r\n            logit = logit.view(logit.size(0), logit.size(1), -1)\r\n            logit = logit.permute(0, 2, 1).contiguous()\r\n            logit = logit.view(-1, logit.size(-1))\r\n        target = torch.squeeze(target, 1)\r\n        target = target.view(-1, 1)\r\n        alpha = self.alpha\r\n\r\n        if alpha is None:\r\n            alpha = torch.ones(num_class, 1)\r\n        elif isinstance(alpha, (list, np.ndarray)):\r\n            assert len(alpha) == num_class\r\n            alpha = torch.FloatTensor(alpha).view(num_class, 1)\r\n            alpha = alpha / alpha.sum()\r\n        elif isinstance(alpha, float):\r\n            alpha = torch.ones(num_class, 1)\r\n            alpha = alpha * (1 - self.alpha)\r\n            alpha[self.balance_index] = self.alpha\r\n\r\n        else:\r\n            raise TypeError('Not support alpha type')\r\n\r\n        if alpha.device != logit.device:\r\n            alpha = alpha.to(logit.device)\r\n\r\n        idx = target.cpu().long()\r\n\r\n        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()\r\n        one_hot_key = one_hot_key.scatter_(1, idx, 1)\r\n        if one_hot_key.device != logit.device:\r\n            one_hot_key = one_hot_key.to(logit.device)\r\n\r\n        if self.smooth:\r\n            one_hot_key = torch.clamp(\r\n                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)\r\n        pt = (one_hot_key * logit).sum(1) + self.smooth\r\n        logpt = pt.log()\r\n\r\n        gamma = self.gamma\r\n\r\n        alpha = alpha[idx]\r\n        alpha = torch.squeeze(alpha)\r\n        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt\r\n\r\n        if self.size_average:\r\n            loss = loss.mean()\r\n        return loss\r\n\r\n\r\nclass BinaryDiceLoss(nn.Module):\r\n    def __init__(self):\r\n        super(BinaryDiceLoss, self).__init__()\r\n\r\n    def forward(self, input, targets):\r\n        # 获取每个批次的大小 N\r\n        N = targets.size()[0]\r\n        # 平滑变量\r\n        smooth = 1\r\n        # 将宽高 reshape 到同一纬度\r\n        input_flat = input.view(N, -1)\r\n        targets_flat = targets.view(N, -1)\r\n\r\n        # 计算交集\r\n        intersection = input_flat * targets_flat\r\n        N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)\r\n        # 计算一个批次中平均每张图的损失\r\n        loss = 1 - N_dice_eff.sum() / N\r\n        return loss\r\n"
  },
  {
    "path": "model.py",
    "content": "from torch import Tensor, nn\r\nimport torch\r\nfrom torch.nn import functional as F\r\n\r\nclass LinearLayer(nn.Module):\r\n    def __init__(self, dim_in, dim_out, k, model):\r\n        super(LinearLayer, self).__init__()\r\n        if 'ViT' in model:\r\n            self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for i in range(k)])\r\n        else:\r\n            self.fc = nn.ModuleList([nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)])\r\n\r\n    def forward(self, tokens):\r\n        for i in range(len(tokens)):\r\n            if len(tokens[i].shape) == 3:\r\n                tokens[i] = self.fc[i](tokens[i][:, 1:, :])\r\n            else:\r\n                B, C, H, W = tokens[i].shape\r\n                tokens[i] = self.fc[i](tokens[i].view(B, C, -1).permute(0, 2, 1).contiguous())\r\n        return tokens\r\n"
  },
  {
    "path": "open_clip/__init__.py",
    "content": "from .coca_model import CoCa\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\nfrom .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss\nfrom .factory import list_models, add_model_config, get_model_config, load_checkpoint\nfrom .loss import ClipLoss, DistillClipLoss, CoCaLoss\nfrom .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \\\n    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype\nfrom .openai import load_openai_model, list_openai_models\nfrom .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \\\n    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained\nfrom .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub\nfrom .tokenizer import SimpleTokenizer, tokenize, decode\nfrom .transform import image_transform, AugmentationCfg\n"
  },
  {
    "path": "open_clip/coca_model.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nfrom dataclasses import dataclass\n\nfrom .transformer import (\n    LayerNormFp32,\n    LayerNorm,\n    QuickGELU,\n    MultimodalTransformer,\n)\nfrom .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower\n\ntry:\n    from transformers import (\n        BeamSearchScorer,\n        LogitsProcessorList,\n        TopPLogitsWarper,\n        TopKLogitsWarper,\n        RepetitionPenaltyLogitsProcessor,\n        MinLengthLogitsProcessor,\n        MaxLengthCriteria,\n        StoppingCriteriaList\n    )\n\n    GENERATION_TYPES = {\n        \"top_k\": TopKLogitsWarper,\n        \"top_p\": TopPLogitsWarper,\n        \"beam_search\": \"beam_search\"\n    }\n    _has_transformers = True\nexcept ImportError as e:\n    GENERATION_TYPES = {\n        \"top_k\": None,\n        \"top_p\": None,\n        \"beam_search\": \"beam_search\"\n    }\n    _has_transformers = False\n\n\n@dataclass\nclass MultimodalCfg(CLIPTextCfg):\n    mlp_ratio: int = 4\n    dim_head: int = 64\n    heads: int = 8\n    n_queries: int = 256\n    attn_pooler_heads: int = 8\n\n\ndef _build_text_decoder_tower(\n        embed_dim,\n        multimodal_cfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None,\n):\n    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n    act_layer = QuickGELU if quick_gelu else nn.GELU\n    norm_layer = (\n        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n    )\n\n    decoder = MultimodalTransformer(\n        context_length=multimodal_cfg.context_length,\n        width=multimodal_cfg.width,\n        heads=multimodal_cfg.heads,\n        layers=multimodal_cfg.layers,\n        ls_init_value=multimodal_cfg.ls_init_value,\n        output_dim=embed_dim,\n        act_layer=act_layer,\n        norm_layer=norm_layer,\n    )\n\n    return decoder\n\n\nclass CoCa(nn.Module):\n    def __init__(\n            self,\n            embed_dim,\n            multimodal_cfg: MultimodalCfg,\n            text_cfg: CLIPTextCfg,\n            vision_cfg: CLIPVisionCfg,\n            quick_gelu: bool = False,\n            cast_dtype: Optional[torch.dtype] = None,\n            pad_id: int = 0,\n    ):\n        super().__init__()\n        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg\n        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg\n\n        self.text = _build_text_tower(\n            embed_dim=embed_dim,\n            text_cfg=text_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        vocab_size = (\n            text_cfg.vocab_size  # for hf models\n            if hasattr(text_cfg, \"hf_model_name\") and text_cfg.hf_model_name is not None\n            else text_cfg.vocab_size\n        )\n\n        self.visual = _build_vision_tower(\n            embed_dim=embed_dim,\n            vision_cfg=vision_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        self.text_decoder = _build_text_decoder_tower(\n            vocab_size,\n            multimodal_cfg=multimodal_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n        self.pad_id = pad_id\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.text.set_grad_checkpointing(enable)\n        self.text_decoder.set_grad_checkpointing(enable)\n\n    # def _encode_image(self, images, out_layers, normalize=True):\n    #     image_latent, tokens_embs = self.visual(images, out_layers)\n    #     image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent\n    #     return image_latent, tokens_embs\n    def _encode_image(self, images, out_layers, normalize=True):\n        image_latent = self.visual(images, out_layers)\n        return image_latent\n\n    def _encode_text(self, text, normalize=True, embed_cls=True):\n        text = text[:, :-1] if embed_cls else text # make space for CLS token\n        text_latent, token_emb = self.text(text)\n        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent\n        return text_latent, token_emb\n\n    # def encode_image(self, images, out_layers, normalize=True):\n    #     image_latent, _ = self._encode_image(images, out_layers, normalize=normalize)\n    #     return image_latent\n    def encode_image(self, images, out_layers, normalize=True):\n        image_latent = self._encode_image(images, out_layers, normalize=normalize)\n        return image_latent\n\n    def encode_text(self, text, normalize=True, embed_cls=True):\n        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)\n        return text_latent\n\n    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):\n        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)\n        if image_latent is None or image_embs is None:\n            image_latent, image_embs = self._encode_image(image)\n\n        # TODO: add assertion to avoid bugs?\n        labels = text[:, -token_embs.shape[1]:]\n\n        logits = self.text_decoder(image_embs, token_embs)\n        return {\n            \"image_features\": image_latent,\n            \"text_features\": text_latent,\n            \"logits\": logits,\n            \"labels\": labels,\n            \"logit_scale\": self.logit_scale.exp()\n        }\n\n    def generate(\n        self,\n        image,\n        text=None,\n        seq_len=30,\n        max_seq_len=77,\n        temperature=1.,\n        generation_type=\"beam_search\",\n        top_p=0.1,  # keep tokens in the 1 - top_p quantile\n        top_k=1,  # keeps the top_k most probable tokens\n        pad_token_id=None,\n        eos_token_id=None,\n        sot_token_id=None,\n        num_beams=6,\n        num_beam_groups=3,\n        min_seq_len=5,\n        stopping_criteria=None,\n        repetition_penalty=1.0,\n        fixed_output_length=False # if True output.shape == (batch_size, seq_len)\n    ):\n        # taking many ideas and components from HuggingFace GenerationMixin\n        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation\n        assert _has_transformers, \"Please install transformers for generate functionality. `pip install transformers`.\"\n        assert seq_len > min_seq_len, \"seq_len must be larger than min_seq_len\"\n\n        with torch.no_grad():\n            sot_token_id = 49406 if sot_token_id is None else sot_token_id\n            eos_token_id = 49407 if eos_token_id is None else eos_token_id\n            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id\n            logit_processor = LogitsProcessorList(\n                [\n                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),\n                    RepetitionPenaltyLogitsProcessor(repetition_penalty),\n                ]\n            )\n\n            if stopping_criteria is None:\n                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]\n\n            stopping_criteria = StoppingCriteriaList(\n                stopping_criteria\n            )\n\n            device = image.device\n\n            if generation_type == \"beam_search\":\n                output = self._generate_beamsearch(\n                    image_inputs = image,\n                    pad_token_id=pad_token_id,\n                    eos_token_id=eos_token_id,\n                    sot_token_id=sot_token_id,\n                    num_beams=num_beams,\n                    num_beam_groups=num_beam_groups,\n                    min_seq_len=min_seq_len,\n                    stopping_criteria=stopping_criteria,\n                    logit_processor=logit_processor,\n                )\n                if fixed_output_length and output.shape[1] < seq_len:\n                    return torch.cat(\n                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),\n                        dim=1\n                    )\n                return output\n\n            elif generation_type == \"top_p\":\n                logit_warper = GENERATION_TYPES[generation_type](top_p)\n            elif generation_type == \"top_k\":\n                logit_warper = GENERATION_TYPES[generation_type](top_k)\n            else:\n                raise ValueError(\n                    f\"generation_type has to be one of \"\n                    f\"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}.\"\n                )\n\n            image_latent, image_embs = self._encode_image(image)\n\n            if text is None:\n                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id\n\n            was_training = self.training\n            num_dims = len(text.shape)\n\n            if num_dims == 1:\n                text = text[None, :]\n\n            cur_len = text.shape[1]\n            self.eval()\n            out = text\n\n            while True:\n                x = out[:, -max_seq_len:]\n                cur_len = x.shape[1]\n                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)[\"logits\"][:, -1]\n                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)\n                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id\n\n                if mask.all():\n                    if not fixed_output_length:\n                        break\n                else:\n                    logits = logits[~mask, :]\n                    filtered_logits = logit_processor(x[~mask, :], logits)\n                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)\n                    probs = F.softmax(filtered_logits / temperature, dim=-1)\n\n                    if (cur_len + 1 == seq_len):\n                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id\n                    else:\n                        sample[~mask, :] = torch.multinomial(probs, 1)\n\n                out = torch.cat((out, sample), dim=-1)\n\n                cur_len += 1\n\n                if stopping_criteria(out, None):\n                    break\n\n            if num_dims == 1:\n                out = out.squeeze(0)\n\n            self.train(was_training)\n            return out\n\n    def _generate_beamsearch(\n            self,\n            image_inputs,\n            pad_token_id=None,\n            eos_token_id=None,\n            sot_token_id=None,\n            num_beams=6,\n            num_beam_groups=3,\n            min_seq_len=5,\n            stopping_criteria=None,\n            logit_processor=None,\n            logit_warper=None,\n    ):\n        device = image_inputs.device\n        batch_size = image_inputs.shape[0]\n        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)\n        image_latent, image_embs = self._encode_image(image_inputs)\n\n        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)\n        input_ids = input_ids * sot_token_id\n        beam_scorer = BeamSearchScorer(\n            batch_size=batch_size,\n            num_beams=num_beams,\n            device=device,\n            num_beam_groups=num_beam_groups,\n        )\n        # instantiate logits processors\n        logits_processor = (\n            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])\n            if logit_processor is None\n            else logit_processor\n        )\n\n        batch_size = len(beam_scorer._beam_hyps)\n        num_beams = beam_scorer.num_beams\n        num_beam_groups = beam_scorer.num_beam_groups\n        num_sub_beams = num_beams // num_beam_groups\n        batch_beam_size, cur_len = input_ids.shape\n        beam_indices = None\n\n        if num_beams * batch_size != batch_beam_size:\n            raise ValueError(\n                f\"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}.\"\n            )\n\n        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)\n        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in\n        # the same group don't produce same tokens everytime.\n        beam_scores[:, ::num_sub_beams] = 0\n        beam_scores = beam_scores.view((batch_size * num_beams,))\n\n        while True:\n\n            # predicted tokens in cur_len step\n            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)\n\n            # indices which will form the beams in the next time step\n            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)\n\n            # do one decoder step on all beams of all sentences in batch\n            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)\n            outputs = self(\n                model_inputs['images'],\n                model_inputs['text'],\n                embed_cls=False,\n                image_latent=image_latent,\n                image_embs=image_embs\n            )\n\n            for beam_group_idx in range(num_beam_groups):\n                group_start_idx = beam_group_idx * num_sub_beams\n                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)\n                group_size = group_end_idx - group_start_idx\n\n                # indices of beams of current group among all sentences in batch\n                batch_group_indices = []\n\n                for batch_idx in range(batch_size):\n                    batch_group_indices.extend(\n                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]\n                    )\n                group_input_ids = input_ids[batch_group_indices]\n\n                # select outputs of beams of currentg group only\n                next_token_logits = outputs['logits'][batch_group_indices, -1, :]\n                vocab_size = next_token_logits.shape[-1]\n\n                next_token_scores_processed = logits_processor(\n                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx\n                )\n                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)\n                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)\n\n                # reshape for beam search\n                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)\n\n                next_token_scores, next_tokens = torch.topk(\n                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True\n                )\n\n                next_indices = torch.div(next_tokens, vocab_size, rounding_mode=\"floor\")\n                next_tokens = next_tokens % vocab_size\n\n                # stateless\n                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n                beam_outputs = beam_scorer.process(\n                    group_input_ids,\n                    next_token_scores,\n                    next_tokens,\n                    next_indices,\n                    pad_token_id=pad_token_id,\n                    eos_token_id=eos_token_id,\n                    beam_indices=process_beam_indices,\n                )\n                beam_scores[batch_group_indices] = beam_outputs[\"next_beam_scores\"]\n                beam_next_tokens = beam_outputs[\"next_beam_tokens\"]\n                beam_idx = beam_outputs[\"next_beam_indices\"]\n\n                input_ids[batch_group_indices] = group_input_ids[beam_idx]\n                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)\n                current_tokens[batch_group_indices] = group_input_ids[:, -1]\n\n                # (beam_idx // group_size) -> batch_idx\n                # (beam_idx % group_size) -> offset of idx inside the group\n                reordering_indices[batch_group_indices] = (\n                    num_beams * torch.div(beam_idx, group_size, rounding_mode=\"floor\") + group_start_idx + (beam_idx % group_size)\n                )\n\n            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)\n\n            # increase cur_len\n            cur_len = cur_len + 1\n            if beam_scorer.is_done or stopping_criteria(input_ids, None):\n                break\n\n        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n        sequence_outputs = beam_scorer.finalize(\n            input_ids,\n            beam_scores,\n            next_tokens,\n            next_indices,\n            pad_token_id=pad_token_id,\n            eos_token_id=eos_token_id,\n            max_length=stopping_criteria.max_length,\n            beam_indices=final_beam_indices,\n        )\n        return sequence_outputs['sequences']\n\n\ndef prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):\n    if past:\n        input_ids = input_ids[:, -1].unsqueeze(-1)\n\n    attention_mask = kwargs.get(\"attention_mask\", None)\n    position_ids = kwargs.get(\"position_ids\", None)\n\n    if attention_mask is not None and position_ids is None:\n        # create position_ids on the fly for batch generation\n        position_ids = attention_mask.long().cumsum(-1) - 1\n        position_ids.masked_fill_(attention_mask == 0, 1)\n    else:\n        position_ids = None\n    return {\n        \"text\": input_ids,\n        \"images\": image_inputs,\n        \"past_key_values\": past,\n        \"position_ids\": position_ids,\n        \"attention_mask\": attention_mask,\n    }\n"
  },
  {
    "path": "open_clip/constants.py",
    "content": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\n"
  },
  {
    "path": "open_clip/factory.py",
    "content": "import json\nimport logging\nimport os\nimport pathlib\nimport re\nimport numpy as np\nfrom copy import deepcopy\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\n\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\nfrom .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\\\n    resize_pos_embed, get_cast_dtype\nfrom .coca_model import CoCa\nfrom .loss import ClipLoss, DistillClipLoss, CoCaLoss\nfrom .openai import load_openai_model\nfrom .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf\nfrom .transform import image_transform, AugmentationCfg\nfrom .tokenizer import HFTokenizer, tokenize\n\n\nHF_HUB_PREFIX = 'hf-hub:'\n_MODEL_CONFIG_PATHS = [Path(__file__).parent / f\"model_configs/\"]\n_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs\n\n\ndef _natural_key(string_):\n    return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\n\n\ndef _rescan_model_configs():\n    global _MODEL_CONFIGS\n\n    config_ext = ('.json',)\n    config_files = []\n    for config_path in _MODEL_CONFIG_PATHS:\n        if config_path.is_file() and config_path.suffix in config_ext:\n            config_files.append(config_path)\n        elif config_path.is_dir():\n            for ext in config_ext:\n                config_files.extend(config_path.glob(f'*{ext}'))\n\n    for cf in config_files:\n        with open(cf, 'r') as f:\n            model_cfg = json.load(f)\n            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):\n                _MODEL_CONFIGS[cf.stem] = model_cfg\n\n    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}\n\n\n_rescan_model_configs()  # initial populate of model config registry\n\n\ndef list_models():\n    \"\"\" enumerate available model architectures based on config files \"\"\"\n    return list(_MODEL_CONFIGS.keys())\n\n\ndef add_model_config(path):\n    \"\"\" add model config path or file and update registry \"\"\"\n    if not isinstance(path, Path):\n        path = Path(path)\n    _MODEL_CONFIG_PATHS.append(path)\n    _rescan_model_configs()\n\n\ndef get_model_config(model_name):\n    if model_name in _MODEL_CONFIGS:\n        return deepcopy(_MODEL_CONFIGS[model_name])\n    else:\n        return None\n\n\ndef get_tokenizer(model_name):\n    if model_name.startswith(HF_HUB_PREFIX):\n        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])\n    else:\n        config = get_model_config(model_name)\n        tokenizer = HFTokenizer(\n            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize\n    return tokenizer\n\n\ndef load_state_dict(checkpoint_path: str, map_location='cpu'):\n    checkpoint = torch.load(checkpoint_path, map_location=map_location)\n    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    else:\n        state_dict = checkpoint\n    if next(iter(state_dict.items()))[0].startswith('module'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n    return state_dict\n\n\ndef load_checkpoint(model, checkpoint_path, strict=True):\n    state_dict = load_state_dict(checkpoint_path)\n    # detect old format and make compatible with new format\n    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):\n        state_dict = convert_to_custom_text_state_dict(state_dict)\n    resize_pos_embed(state_dict, model)\n    incompatible_keys = model.load_state_dict(state_dict, strict=strict)\n    return incompatible_keys\n\n\ndef create_model(\n        model_name: str,\n        img_size: int,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n        force_custom_text: bool = False,\n        force_patch_dropout: Optional[float] = None,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        pretrained_image: bool = False,\n        pretrained_hf: bool = True,\n        cache_dir: Optional[str] = None,\n        output_dict: Optional[bool] = None,\n        require_pretrained: bool = False,\n):\n    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)\n    if has_hf_hub_prefix:\n        model_id = model_name[len(HF_HUB_PREFIX):]\n        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)\n\n        with open(config_path, 'r', encoding='utf-8') as f:\n            config = json.load(f)\n        pretrained_cfg = config['preprocess_cfg']\n        model_cfg = config['model_cfg']\n    else:\n        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names\n        checkpoint_path = None\n        pretrained_cfg = {}\n        model_cfg = None\n\n    if isinstance(device, str):\n        device = torch.device(device)\n\n    if pretrained and pretrained.lower() == 'openai':\n        logging.info(f'Loading pretrained {model_name} from OpenAI.')\n        model_cfg = model_cfg or get_model_config(model_name)\n        if model_cfg['vision_cfg']['image_size'] != img_size:\n            model_cfg['vision_cfg']['image_size'] = img_size\n            cast_dtype = get_cast_dtype(precision)\n\n            model_pre = load_openai_model(\n                model_name,\n                precision=precision,\n                device=device,\n                jit=jit,\n                cache_dir=cache_dir,\n            )\n            state_dict = model_pre.state_dict()\n\n            # to always output dict even if it is clip\n            if output_dict and hasattr(model_pre, \"output_dict\"):\n                model_pre.output_dict = True\n\n            model = CLIP(**model_cfg, cast_dtype=cast_dtype)\n            ### for resnet\n            if not hasattr(model.visual, 'grid_size'):\n                model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1))\n            resize_pos_embed(state_dict, model)\n            incompatible_keys = model.load_state_dict(state_dict, strict=True)\n            model.to(device=device)\n            if precision in (\"fp16\", \"bf16\"):\n                convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)\n\n            # set image / mean metadata from pretrained_cfg if available, or use default\n            model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN\n            model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD\n\n            # to always output dict even if it is clip\n            if output_dict and hasattr(model, \"output_dict\"):\n                model.output_dict = True\n\n            if jit:\n                model = torch.jit.script(model)\n        else:\n            model = load_openai_model(\n                model_name,\n                precision=precision,\n                device=device,\n                jit=jit,\n                cache_dir=cache_dir,\n            )\n\n            # to always output dict even if it is clip\n            if output_dict and hasattr(model, \"output_dict\"):\n                model.output_dict = True\n    else:\n        model_cfg = model_cfg or get_model_config(model_name)\n        model_cfg['vision_cfg']['image_size'] = img_size\n        if model_cfg is not None:\n            logging.info(f'Loaded {model_name} model config.')\n            pass\n        else:\n            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')\n            raise RuntimeError(f'Model config for {model_name} not found.')\n\n        if force_quick_gelu:\n            # override for use of QuickGELU on non-OpenAI transformer models\n            model_cfg[\"quick_gelu\"] = True\n\n        if force_patch_dropout is not None:\n            # override the default patch dropout value\n            model_cfg[\"vision_cfg\"][\"patch_dropout\"] = force_patch_dropout\n\n        if force_image_size is not None:\n            # override model config's image size\n            model_cfg[\"vision_cfg\"][\"image_size\"] = force_image_size\n\n        if pretrained_image:\n            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):\n                # pretrained weight loading for timm models set via vision_cfg\n                model_cfg['vision_cfg']['timm_model_pretrained'] = True\n            else:\n                assert False, 'pretrained image towers currently only supported for timm models'\n\n        cast_dtype = get_cast_dtype(precision)\n        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})\n        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model\n\n        if custom_text:\n            if is_hf_model:\n                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf\n            if \"coca\" in model_name:\n                model = CoCa(**model_cfg, cast_dtype=cast_dtype)\n            else:\n                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)\n        else:\n            model = CLIP(**model_cfg, cast_dtype=cast_dtype)\n\n        pretrained_loaded = False\n        if pretrained:\n            checkpoint_path = ''\n            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)\n            if pretrained_cfg:\n                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)\n            elif os.path.exists(pretrained):\n                checkpoint_path = pretrained\n\n            if checkpoint_path:\n                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\n                load_checkpoint(model, checkpoint_path)\n            else:\n                error_str = (\n                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'\n                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')\n                logging.warning(error_str)\n                raise RuntimeError(error_str)\n            pretrained_loaded = True\n        elif has_hf_hub_prefix:\n            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\n            load_checkpoint(model, checkpoint_path)\n            pretrained_loaded = True\n\n        if require_pretrained and not pretrained_loaded:\n            # callers of create_model_from_pretrained always expect pretrained weights\n            raise RuntimeError(\n                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')\n\n        model.to(device=device)\n        if precision in (\"fp16\", \"bf16\"):\n            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)\n\n        # set image / mean metadata from pretrained_cfg if available, or use default\n        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN\n        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD\n\n        # to always output dict even if it is clip\n        if output_dict and hasattr(model, \"output_dict\"):\n            model.output_dict = True\n\n        if jit:\n            model = torch.jit.script(model)\n\n    return model\n\n\ndef create_loss(args):\n    if args.distill:\n        return DistillClipLoss(\n            local_loss=args.local_loss,\n            gather_with_grad=args.gather_with_grad,\n            cache_labels=True,\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod,\n        )\n    elif \"coca\" in args.model.lower():\n        return CoCaLoss(\n            caption_loss_weight=args.coca_caption_loss_weight,\n            clip_loss_weight=args.coca_contrastive_loss_weight,\n            local_loss=args.local_loss,\n            gather_with_grad=args.gather_with_grad,\n            cache_labels=True,\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod,\n        )\n    return ClipLoss(\n        local_loss=args.local_loss,\n        gather_with_grad=args.gather_with_grad,\n        cache_labels=True,\n        rank=args.rank,\n        world_size=args.world_size,\n        use_horovod=args.horovod,\n    )\n\n\ndef create_model_and_transforms(\n        model_name: str,\n        img_size: int,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n        force_custom_text: bool = False,\n        force_patch_dropout: Optional[float] = None,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        pretrained_image: bool = False,\n        pretrained_hf: bool = True,\n        image_mean: Optional[Tuple[float, ...]] = None,\n        image_std: Optional[Tuple[float, ...]] = None,\n        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n        cache_dir: Optional[str] = None,\n        output_dict: Optional[bool] = None,\n):\n    model = create_model(\n        model_name,\n        img_size,\n        pretrained,\n        precision=precision,\n        device=device,\n        jit=jit,\n        force_quick_gelu=force_quick_gelu,\n        force_custom_text=force_custom_text,\n        force_patch_dropout=force_patch_dropout,\n        force_image_size=force_image_size,\n        pretrained_image=pretrained_image,\n        pretrained_hf=pretrained_hf,\n        cache_dir=cache_dir,\n        output_dict=output_dict,\n    )\n\n    image_mean = image_mean or getattr(model.visual, 'image_mean', None)\n    image_std = image_std or getattr(model.visual, 'image_std', None)\n    preprocess_train = image_transform(\n        model.visual.image_size,\n        is_train=True,\n        mean=image_mean,\n        std=image_std,\n        aug_cfg=aug_cfg,\n    )\n    preprocess_val = image_transform(\n        model.visual.image_size,\n        is_train=False,\n        mean=image_mean,\n        std=image_std,\n    )\n\n    return model, preprocess_train, preprocess_val\n\n\ndef create_model_from_pretrained(\n        model_name: str,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n        force_custom_text: bool = False,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        return_transform: bool = True,\n        image_mean: Optional[Tuple[float, ...]] = None,\n        image_std: Optional[Tuple[float, ...]] = None,\n        cache_dir: Optional[str] = None,\n):\n    model = create_model(\n        model_name,\n        pretrained,\n        precision=precision,\n        device=device,\n        jit=jit,\n        force_quick_gelu=force_quick_gelu,\n        force_custom_text=force_custom_text,\n        force_image_size=force_image_size,\n        cache_dir=cache_dir,\n        require_pretrained=True,\n    )\n\n    if not return_transform:\n        return model\n\n    image_mean = image_mean or getattr(model.visual, 'image_mean', None)\n    image_std = image_std or getattr(model.visual, 'image_std', None)\n    preprocess = image_transform(\n        model.visual.image_size,\n        is_train=False,\n        mean=image_mean,\n        std=image_std,\n    )\n\n    return model, preprocess\n"
  },
  {
    "path": "open_clip/generation_utils.py",
    "content": ""
  },
  {
    "path": "open_clip/hf_configs.py",
    "content": "# HF architecture dict:\narch_dict = {\n    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n    \"roberta\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"hidden_size\",\n            \"heads\": \"num_attention_heads\",\n            \"layers\": \"num_hidden_layers\",\n            \"layer_attr\": \"layer\",\n            \"token_embeddings_attr\": \"embeddings\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig\n    \"xlm-roberta\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"hidden_size\",\n            \"heads\": \"num_attention_heads\",\n            \"layers\": \"num_hidden_layers\",\n            \"layer_attr\": \"layer\",\n            \"token_embeddings_attr\": \"embeddings\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5\n    \"mt5\": {\n        \"config_names\": {\n            # unlimited seqlen\n            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273\n            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374\n            \"context_length\": \"\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"d_model\",\n            \"heads\": \"num_heads\",\n            \"layers\": \"num_layers\",\n            \"layer_attr\": \"block\",\n            \"token_embeddings_attr\": \"embed_tokens\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n}\n"
  },
  {
    "path": "open_clip/hf_model.py",
    "content": "\"\"\" huggingface model adapter\n\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.\n\"\"\"\n\nimport re\n\nimport torch\nimport torch.nn as nn\nfrom torch import TensorType\n\ntry:\n    import transformers\n    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig\n    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \\\n        BaseModelOutputWithPoolingAndCrossAttentions\nexcept ImportError as e:\n    transformers = None\n\n\n    class BaseModelOutput:\n        pass\n\n\n    class PretrainedConfig:\n        pass\n\nfrom .hf_configs import arch_dict\n\n\n# utils\ndef _camel2snake(s):\n    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()\n\n\n# TODO: ?last - for gpt-like models\n_POOLERS = {}\n\n\ndef register_pooler(cls):\n    \"\"\"Decorator registering pooler class\"\"\"\n    _POOLERS[_camel2snake(cls.__name__)] = cls\n    return cls\n\n\n@register_pooler\nclass MeanPooler(nn.Module):\n    \"\"\"Mean pooling\"\"\"\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)\n        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)\n\n\n@register_pooler\nclass MaxPooler(nn.Module):\n    \"\"\"Max pooling\"\"\"\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)\n        return masked_output.max(1).values\n\n\n@register_pooler\nclass ClsPooler(nn.Module):\n    \"\"\"CLS token pooling\"\"\"\n\n    def __init__(self, use_pooler_output=True):\n        super().__init__()\n        self.cls_token_position = 0\n        self.use_pooler_output = use_pooler_output\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        if (self.use_pooler_output and\n            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and\n            (x.pooler_output is not None)\n        ):\n            return x.pooler_output\n\n        return x.last_hidden_state[:, self.cls_token_position, :]\n\n\nclass HFTextEncoder(nn.Module):\n    \"\"\"HuggingFace model adapter\"\"\"\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            model_name_or_path: str,\n            output_dim: int,\n            config: PretrainedConfig = None,\n            pooler_type: str = None,\n            proj: str = None,\n            pretrained: bool = True,\n            output_tokens: bool = False,\n    ):\n        super().__init__()\n        self.output_tokens = output_tokens\n        self.output_dim = output_dim\n\n        # TODO: find better way to get this information\n        uses_transformer_pooler = (pooler_type == \"cls_pooler\")\n\n        if transformers is None:\n            raise RuntimeError(\"Please `pip install transformers` to use pre-trained HuggingFace models\")\n        if config is None:\n            self.config = AutoConfig.from_pretrained(model_name_or_path)\n            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (\n                AutoModel.from_config, self.config)\n            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??\n            if hasattr(self.config, \"is_encoder_decoder\") and self.config.is_encoder_decoder:\n                self.transformer = create_func(model_args)\n                self.transformer = self.transformer.encoder\n            else:\n                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)\n        else:\n            self.config = config\n            self.transformer = AutoModel.from_config(config)\n        if pooler_type is None:  # get default arch pooler\n            pooler_type = (arch_dict[self.config.model_type][\"pooler\"])\n        \n        self.pooler = _POOLERS[pooler_type]()\n\n        d_model = getattr(self.config, arch_dict[self.config.model_type][\"config_names\"][\"width\"])\n        if (d_model == output_dim) and (proj is None):  # do we always need a proj?\n            self.proj = nn.Identity()\n        elif proj == 'linear':\n            self.proj = nn.Linear(d_model, output_dim, bias=False)\n        elif proj == 'mlp':\n            hidden_size = (d_model + output_dim) // 2\n            self.proj = nn.Sequential(\n                nn.Linear(d_model, hidden_size, bias=False),\n                nn.GELU(),\n                nn.Linear(hidden_size, output_dim, bias=False),\n            )\n\n    def forward(self, x: TensorType):\n        attn_mask = (x != self.config.pad_token_id).long()\n        out = self.transformer(input_ids=x, attention_mask=attn_mask)\n        pooled_out = self.pooler(out, attn_mask)\n        projected = self.proj(pooled_out)\n\n        seq_len = out.last_hidden_state.shape[1]\n        tokens = (\n            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] \n            if type(self.pooler) == ClsPooler \n            else out.last_hidden_state\n        )\n        \n        if self.output_tokens:\n            return projected, tokens\n        return projected\n\n    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n        if not unlocked_layers:  # full freezing\n            for n, p in self.transformer.named_parameters():\n                p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\n            return\n\n        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer\n        layer_list = getattr(encoder, arch_dict[self.config.model_type][\"config_names\"][\"layer_attr\"])\n        print(f\"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model\")\n        embeddings = getattr(\n            self.transformer, arch_dict[self.config.model_type][\"config_names\"][\"token_embeddings_attr\"])\n        modules = [embeddings, *layer_list][:-unlocked_layers]\n        # freeze layers\n        for module in modules:\n            for n, p in module.named_parameters():\n                p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.gradient_checkpointing_enable()\n\n    def init_parameters(self):\n        pass\n"
  },
  {
    "path": "open_clip/loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\ntry:\n    import torch.distributed.nn\n    from torch import distributed as dist\n\n    has_distributed = True\nexcept ImportError:\n    has_distributed = False\n\ntry:\n    import horovod.torch as hvd\nexcept ImportError:\n    hvd = None\n\n\ndef gather_features(\n        image_features,\n        text_features,\n        local_loss=False,\n        gather_with_grad=False,\n        rank=0,\n        world_size=1,\n        use_horovod=False\n):\n    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'\n    if use_horovod:\n        assert hvd is not None, 'Please install horovod'\n        if gather_with_grad:\n            all_image_features = hvd.allgather(image_features)\n            all_text_features = hvd.allgather(text_features)\n        else:\n            with torch.no_grad():\n                all_image_features = hvd.allgather(image_features)\n                all_text_features = hvd.allgather(text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))\n                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n                all_image_features = torch.cat(gathered_image_features, dim=0)\n                all_text_features = torch.cat(gathered_text_features, dim=0)\n    else:\n        # We gather tensors from all gpus\n        if gather_with_grad:\n            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)\n            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)\n        else:\n            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]\n            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]\n            dist.all_gather(gathered_image_features, image_features)\n            dist.all_gather(gathered_text_features, text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n            all_image_features = torch.cat(gathered_image_features, dim=0)\n            all_text_features = torch.cat(gathered_text_features, dim=0)\n\n    return all_image_features, all_text_features\n\n\nclass ClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            local_loss=False,\n            gather_with_grad=False,\n            cache_labels=False,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.local_loss = local_loss\n        self.gather_with_grad = gather_with_grad\n        self.cache_labels = cache_labels\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def get_ground_truth(self, device, num_logits) -> torch.Tensor:\n        # calculated ground-truth and cache if enabled\n        if self.prev_num_logits != num_logits or device not in self.labels:\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\n            if self.world_size > 1 and self.local_loss:\n                labels = labels + num_logits * self.rank\n            if self.cache_labels:\n                self.labels[device] = labels\n                self.prev_num_logits = num_logits\n        else:\n            labels = self.labels[device]\n        return labels\n\n    def get_logits(self, image_features, text_features, logit_scale):\n        if self.world_size > 1:\n            all_image_features, all_text_features = gather_features(\n                image_features, text_features,\n                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)\n\n            if self.local_loss:\n                logits_per_image = logit_scale * image_features @ all_text_features.T\n                logits_per_text = logit_scale * text_features @ all_image_features.T\n            else:\n                logits_per_image = logit_scale * all_image_features @ all_text_features.T\n                logits_per_text = logits_per_image.T\n        else:\n            logits_per_image = logit_scale * image_features @ text_features.T\n            logits_per_text = logit_scale * text_features @ image_features.T\n        \n        return logits_per_image, logits_per_text\n\n    def forward(self, image_features, text_features, logit_scale, output_dict=False):\n        device = image_features.device\n        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)\n\n        labels = self.get_ground_truth(device, logits_per_image.shape[0])\n\n        total_loss = (\n            F.cross_entropy(logits_per_image, labels) +\n            F.cross_entropy(logits_per_text, labels)\n        ) / 2\n\n        return {\"contrastive_loss\": total_loss} if output_dict else total_loss\n\n\nclass CoCaLoss(ClipLoss):\n    def __init__(\n            self,\n            caption_loss_weight,\n            clip_loss_weight,\n            pad_id=0,  # pad_token for open_clip custom tokenizer\n            local_loss=False,\n            gather_with_grad=False,\n            cache_labels=False,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__(\n            local_loss=local_loss,\n            gather_with_grad=gather_with_grad,\n            cache_labels=cache_labels,\n            rank=rank,\n            world_size=world_size,\n            use_horovod=use_horovod\n        )\n\n        self.clip_loss_weight = clip_loss_weight\n        self.caption_loss_weight = caption_loss_weight\n        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)\n\n    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):\n        clip_loss = super().forward(image_features, text_features, logit_scale)\n        clip_loss = self.clip_loss_weight * clip_loss\n\n        caption_loss = self.caption_loss(\n            logits.permute(0, 2, 1),\n            labels,\n        )\n        caption_loss = caption_loss * self.caption_loss_weight\n\n        if output_dict:\n            return {\"contrastive_loss\": clip_loss, \"caption_loss\": caption_loss}\n\n        return clip_loss, caption_loss\n\n\nclass DistillClipLoss(ClipLoss):\n\n    def dist_loss(self, teacher_logits, student_logits):\n        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)\n\n    def forward(\n            self,\n            image_features,\n            text_features,\n            logit_scale,\n            dist_image_features,\n            dist_text_features,\n            dist_logit_scale,\n            output_dict=False,\n    ):\n        logits_per_image, logits_per_text = \\\n            self.get_logits(image_features, text_features, logit_scale)\n\n        dist_logits_per_image, dist_logits_per_text = \\\n            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)\n\n        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])\n\n        contrastive_loss = (\n            F.cross_entropy(logits_per_image, labels) +\n            F.cross_entropy(logits_per_text, labels)\n        ) / 2\n\n        distill_loss = (\n            self.dist_loss(dist_logits_per_image, logits_per_image) +\n            self.dist_loss(dist_logits_per_text, logits_per_text)\n        ) / 2\n\n        if output_dict:\n            return {\"contrastive_loss\": contrastive_loss, \"distill_loss\": distill_loss}\n\n        return contrastive_loss, distill_loss\n"
  },
  {
    "path": "open_clip/model.py",
    "content": "\"\"\" CLIP Model\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nfrom dataclasses import dataclass\nimport logging\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.utils.checkpoint import checkpoint\n\nfrom .hf_model import HFTextEncoder\nfrom .modified_resnet import ModifiedResNet\nfrom .timm_model import TimmModel\nfrom .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer\nfrom .utils import to_2tuple\n\n\n@dataclass\nclass CLIPVisionCfg:\n    layers: Union[Tuple[int, int, int, int], int] = 12\n    width: int = 768\n    head_width: int = 64\n    mlp_ratio: float = 4.0\n    patch_size: int = 16\n    image_size: Union[Tuple[int, int], int] = 224\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\n    input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design\n    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)\n    attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer\n    n_queries: int = 256 # n_queries for attentional pooler\n    attn_pooler_heads: int = 8 # n heads for attentional_pooling\n    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size\n    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model\n    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')\n    timm_proj_bias: bool = False  # enable bias final projection\n    timm_drop: float = 0.  # head dropout\n    timm_drop_path: Optional[float] = None  # backbone stochastic depth\n    output_tokens: bool = False\n\n\n@dataclass\nclass CLIPTextCfg:\n    context_length: int = 77\n    vocab_size: int = 49408\n    width: int = 512\n    heads: int = 8\n    layers: int = 12\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    hf_model_name: str = None\n    hf_tokenizer_name: str = None\n    hf_model_pretrained: bool = True\n    proj: str = 'mlp'\n    pooler_type: str = 'mean_pooler'\n    embed_cls: bool = False\n    pad_id: int = 0\n    output_tokens: bool = False\n\n\ndef get_cast_dtype(precision: str):\n    cast_dtype = None\n    if precision == 'bf16':\n        cast_dtype = torch.bfloat16\n    elif precision == 'fp16':\n        cast_dtype = torch.float16\n    return cast_dtype\n\n\ndef _build_vision_tower(\n        embed_dim: int,\n        vision_cfg: CLIPVisionCfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None\n):\n    if isinstance(vision_cfg, dict):\n        vision_cfg = CLIPVisionCfg(**vision_cfg)\n\n    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\n    # memory efficient in recent PyTorch releases (>= 1.10).\n    # NOTE: timm models always use native GELU regardless of quick_gelu flag.\n    act_layer = QuickGELU if quick_gelu else nn.GELU\n\n    if vision_cfg.timm_model_name:\n        visual = TimmModel(\n            vision_cfg.timm_model_name,\n            pretrained=vision_cfg.timm_model_pretrained,\n            pool=vision_cfg.timm_pool,\n            proj=vision_cfg.timm_proj,\n            proj_bias=vision_cfg.timm_proj_bias,\n            drop=vision_cfg.timm_drop,\n            drop_path=vision_cfg.timm_drop_path,\n            embed_dim=embed_dim,\n            image_size=vision_cfg.image_size,\n        )\n        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models\n    elif isinstance(vision_cfg.layers, (tuple, list)):\n        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width\n        visual = ModifiedResNet(\n            layers=vision_cfg.layers,\n            output_dim=embed_dim,\n            heads=vision_heads,\n            image_size=vision_cfg.image_size,\n            width=vision_cfg.width,\n        )\n    else:\n        vision_heads = vision_cfg.width // vision_cfg.head_width\n        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n        visual = VisionTransformer(\n            image_size=vision_cfg.image_size,\n            patch_size=vision_cfg.patch_size,\n            width=vision_cfg.width,\n            layers=vision_cfg.layers,\n            heads=vision_heads,\n            mlp_ratio=vision_cfg.mlp_ratio,\n            ls_init_value=vision_cfg.ls_init_value,\n            patch_dropout=vision_cfg.patch_dropout,\n            input_patchnorm=vision_cfg.input_patchnorm,\n            global_average_pool=vision_cfg.global_average_pool,\n            attentional_pool=vision_cfg.attentional_pool,\n            n_queries=vision_cfg.n_queries,\n            attn_pooler_heads=vision_cfg.attn_pooler_heads,\n            output_tokens=vision_cfg.output_tokens,\n            output_dim=embed_dim,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n\n    return visual\n\n\ndef _build_text_tower(\n        embed_dim: int,\n        text_cfg: CLIPTextCfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None,\n):\n    if isinstance(text_cfg, dict):\n        text_cfg = CLIPTextCfg(**text_cfg)\n\n    if text_cfg.hf_model_name:\n        text = HFTextEncoder(\n            text_cfg.hf_model_name,\n            output_dim=embed_dim,\n            proj=text_cfg.proj,\n            pooler_type=text_cfg.pooler_type,\n            pretrained=text_cfg.hf_model_pretrained,\n            output_tokens=text_cfg.output_tokens,\n        )\n    else:\n        act_layer = QuickGELU if quick_gelu else nn.GELU\n        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n\n        text = TextTransformer(\n            context_length=text_cfg.context_length,\n            vocab_size=text_cfg.vocab_size,\n            width=text_cfg.width,\n            heads=text_cfg.heads,\n            layers=text_cfg.layers,\n            ls_init_value=text_cfg.ls_init_value,\n            output_dim=embed_dim,\n            embed_cls=text_cfg.embed_cls,\n            output_tokens=text_cfg.output_tokens,\n            pad_id=text_cfg.pad_id,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n    return text\n\n\nclass CLIP(nn.Module):\n    output_dict: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            embed_dim: int,\n            vision_cfg: CLIPVisionCfg,\n            text_cfg: CLIPTextCfg,\n            quick_gelu: bool = False,\n            cast_dtype: Optional[torch.dtype] = None,\n            output_dict: bool = False,\n    ):\n        super().__init__()\n        self.output_dict = output_dict\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n\n        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n        self.transformer = text.transformer\n        self.vocab_size = text.vocab_size\n        self.token_embedding = text.token_embedding\n        self.positional_embedding = text.positional_embedding\n        self.ln_final = text.ln_final\n        self.text_projection = text.text_projection\n        self.register_buffer('attn_mask', text.attn_mask, persistent=False)\n\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.transformer.grad_checkpointing = enable\n\n    def encode_image(self, image, out_layers, normalize: bool = False):\n        features = self.visual(image, out_layers)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def encode_text(self, text, normalize: bool = False):\n        cast_dtype = self.transformer.get_cast_dtype()\n\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding.to(cast_dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\n        return F.normalize(x, dim=-1) if normalize else x\n\n    def forward(self, image, text):\n        image_features = self.encode_image(image, normalize=True)\n        text_features = self.encode_text(text, normalize=True)\n        if self.output_dict:\n            return {\n                \"image_features\": image_features,\n                \"text_features\": text_features,\n                \"logit_scale\": self.logit_scale.exp()\n            }\n        return image_features, text_features, self.logit_scale.exp()\n\n\nclass CustomTextCLIP(nn.Module):\n    output_dict: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            embed_dim: int,\n            vision_cfg: CLIPVisionCfg,\n            text_cfg: CLIPTextCfg,\n            quick_gelu: bool = False,\n            cast_dtype: Optional[torch.dtype] = None,\n            output_dict: bool = False,\n    ):\n        super().__init__()\n        self.output_dict = output_dict\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n\n    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n        self.text.lock(unlocked_layers, freeze_layer_norm)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.text.set_grad_checkpointing(enable)\n\n    def encode_image(self, image, normalize: bool = False):\n        features = self.visual(image)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def encode_text(self, text, normalize: bool = False):\n        features = self.text(text)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def forward(self, image, text):\n        image_features = self.encode_image(image, normalize=True)\n        text_features = self.encode_text(text, normalize=True)\n        if self.output_dict:\n            return {\n                \"image_features\": image_features,\n                \"text_features\": text_features,\n                \"logit_scale\": self.logit_scale.exp()\n            }\n        return image_features, text_features, self.logit_scale.exp()\n\n\ndef convert_weights_to_lp(model: nn.Module, dtype=torch.float16):\n    \"\"\"Convert applicable model parameters to low-precision (bf16 or fp16)\"\"\"\n\n    def _convert_weights(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.to(dtype)\n            if l.bias is not None:\n                l.bias.data = l.bias.data.to(dtype)\n\n        if isinstance(l, (nn.MultiheadAttention, Attention)):\n            for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n                tensor = getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.to(dtype)\n\n        for name in [\"text_projection\", \"proj\"]:\n            if hasattr(l, name):\n                attr = getattr(l, name)\n                if attr is not None:\n                    attr.data = attr.data.to(dtype)\n\n    model.apply(_convert_weights)\n\n\nconvert_weights_to_fp16 = convert_weights_to_lp  # backwards compat\n\n\n# used to maintain checkpoint compatibility\ndef convert_to_custom_text_state_dict(state_dict: dict):\n    if 'text_projection' in state_dict:\n        # old format state_dict, move text tower -> .text\n        new_state_dict = {}\n        for k, v in state_dict.items():\n            if any(k.startswith(p) for p in (\n                'text_projection',\n                'positional_embedding',\n                'token_embedding',\n                'transformer',\n                'ln_final',\n            )):\n                k = 'text.' + k\n            new_state_dict[k] = v\n        return new_state_dict\n    return state_dict\n\n\ndef build_model_from_openai_state_dict(\n        state_dict: dict,\n        quick_gelu=True,\n        cast_dtype=torch.float16,\n):\n    vit = \"visual.proj\" in state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len(\n            [k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n        image_size = vision_patch_size * grid_size\n    else:\n        counts: list = [\n            len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n        vision_patch_size = None\n        assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n        image_size = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n\n    vision_cfg = CLIPVisionCfg(\n        layers=vision_layers,\n        width=vision_width,\n        patch_size=vision_patch_size,\n        image_size=image_size,\n    )\n    text_cfg = CLIPTextCfg(\n        context_length=context_length,\n        vocab_size=vocab_size,\n        width=transformer_width,\n        heads=transformer_heads,\n        layers=transformer_layers,\n    )\n    model = CLIP(\n        embed_dim,\n        vision_cfg=vision_cfg,\n        text_cfg=text_cfg,\n        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU\n        cast_dtype=cast_dtype,\n    )\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        state_dict.pop(key, None)\n\n    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16\n    model.load_state_dict(state_dict)\n    return model.eval()\n\n\ndef trace_model(model, batch_size=256, device=torch.device('cpu')):\n    model.eval()\n    image_size = model.visual.image_size\n    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\n    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)\n    model = torch.jit.trace_module(\n        model,\n        inputs=dict(\n            forward=(example_images, example_text),\n            encode_text=(example_text,),\n            encode_image=(example_images,)\n        ))\n    model.visual.image_size = image_size\n    return model\n\n\ndef resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):\n    # Rescale the grid of position embeddings when loading from state_dict\n    flag = 1\n    old_pos_embed = state_dict.get('visual.positional_embedding', None)\n    if old_pos_embed is None:\n        flag = 0\n        old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None)\n    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\n        return\n    grid_size = to_2tuple(model.visual.grid_size)\n    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)\n    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\n    if new_seq_len == old_pos_embed.shape[0]:\n        return\n\n    if extra_tokens:\n        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\n    else:\n        pos_emb_tok, pos_emb_img = None, old_pos_embed\n    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\n\n    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\n    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\n    pos_emb_img = F.interpolate(\n        pos_emb_img,\n        size=grid_size,\n        mode=interpolation,\n        antialias=antialias,\n        align_corners=False,\n    )\n    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\n    if pos_emb_tok is not None:\n        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\n    else:\n        new_pos_embed = pos_emb_img\n    if flag:\n        state_dict['visual.positional_embedding'] = new_pos_embed\n    else:\n        state_dict['visual.attnpool.positional_embedding'] = new_pos_embed\n"
  },
  {
    "path": "open_clip/model_configs/RN101-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/RN101.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/RN50-quickgelu.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "open_clip/model_configs/RN50.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/RN50x16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"layers\": [\n            6,\n            8,\n            18,\n            8\n        ],\n        \"width\": 96,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/RN50x4.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": [\n            4,\n            6,\n            10,\n            6\n        ],\n        \"width\": 80,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/RN50x64.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 448,\n        \"layers\": [\n            3,\n            15,\n            36,\n            10\n        ],\n        \"width\": 128,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-16-plus-240.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 240,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-16-plus.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-32-plus-256.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-32-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-H-16.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-L-14-280.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 280,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-L-14-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-L-16-320.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 320,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-L-16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-M-16-alt.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 16,\n        \"ls_init_value\": 1e-4\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-M-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-M-32-alt.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-M-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-S-16-alt.json",
    "content": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 256,\n        \"heads\": 4,\n        \"layers\": 10\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-S-16.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-S-32-alt.json",
    "content": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 256,\n        \"heads\": 4,\n        \"layers\": 10\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-S-32.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-bigG-14.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n        \"head_width\": 104,\n        \"mlp_ratio\": 4.9231,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 32\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-e-14.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 56,\n        \"width\": 1792,\n        \"head_width\": 112,\n        \"mlp_ratio\": 8.5715,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 36\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/ViT-g-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 40,\n        \"width\": 1408,\n        \"head_width\": 88,\n        \"mlp_ratio\": 4.3637,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/coca_ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32,\n        \"attentional_pool\": true,\n        \"attn_pooler_heads\": 8,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"attn_pooler_heads\": 8\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "open_clip/model_configs/coca_ViT-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14,\n        \"attentional_pool\": true,\n        \"attn_pooler_heads\": 8,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"attn_pooler_heads\": 12\n    },\n    \"custom_text\": true\n}\n"
  },
  {
    "path": "open_clip/model_configs/coca_base.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"multimodal_cfg\": {\n        \"width\": 768,\n        \"context_length\": 76,\n        \"vocab_size\": 64000,\n        \"mlp_ratio\": 4,\n        \"layers\": 12,\n        \"dim_head\": 64,\n        \"heads\": 12,\n        \"n_queries\": 256,\n        \"attn_pooler_heads\": 8\n    },\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 18,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 64000,\n        \"layers\": 12,\n        \"heads\": 12,\n        \"width\": 768,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "open_clip/model_configs/coca_roberta-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"roberta-base\",\n        \"hf_tokenizer_name\": \"roberta-base\",\n        \"proj\": \"linear\",\n        \"width\": 768,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"width\": 768,\n        \"heads\": 8,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}\n"
  },
  {
    "path": "open_clip/model_configs/convnext_base.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_base_w.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_base_w_320.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_large.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_large_d.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"mlp\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 16\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_large_d_320.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"mlp\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 16\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_small.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_small\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_tiny.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_tiny\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_xlarge.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 20\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_xxlarge.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/convnext_xxlarge_320.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/mt5-base-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"google/mt5-base\",\n        \"hf_tokenizer_name\": \"google/mt5-base\",\n        \"proj\": \"mlp\",\n        \"pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "open_clip/model_configs/mt5-xl-ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"google/mt5-xl\",\n        \"hf_tokenizer_name\": \"google/mt5-xl\",\n        \"proj\": \"mlp\",\n        \"pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "open_clip/model_configs/roberta-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"roberta-base\",\n        \"hf_tokenizer_name\": \"roberta-base\",\n        \"proj\": \"mlp\",\n        \"pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "open_clip/model_configs/swin_base_patch4_window7_224.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"swin_base_patch4_window7_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/vit_medium_patch16_gap_256.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_medium_patch16_gap_256\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_relpos_medium_patch16_cls_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "open_clip/model_configs/xlm-roberta-base-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"xlm-roberta-base\",\n        \"hf_tokenizer_name\": \"xlm-roberta-base\",\n        \"proj\": \"mlp\",\n        \"pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "open_clip/model_configs/xlm-roberta-large-ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"xlm-roberta-large\",\n        \"hf_tokenizer_name\": \"xlm-roberta-large\",\n        \"proj\": \"mlp\",\n        \"pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "open_clip/modified_resnet.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom open_clip.utils import freeze_batch_norm_2d\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.act1 = nn.ReLU(inplace=True)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.act2 = nn.ReLU(inplace=True)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.act3 = nn.ReLU(inplace=True)\n\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(OrderedDict([\n                (\"-1\", nn.AvgPool2d(stride)),\n                (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n                (\"1\", nn.BatchNorm2d(planes * self.expansion))\n            ]))\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.act1(self.bn1(self.conv1(x)))\n        out = self.act2(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.act3(out)\n        return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x, key=x, value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0.,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False\n        )\n\n        return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n\n    def __init__(self, layers, output_dim, heads, image_size=224, width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.image_size = image_size\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.act1 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.act2 = nn.ReLU(inplace=True)\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.act3 = nn.ReLU(inplace=True)\n        self.avgpool = nn.AvgPool2d(2)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32  # the ResNet feature dimension\n        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\n\n        self.init_parameters()\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def init_parameters(self):\n        if self.attnpool is not None:\n            std = self.attnpool.c_proj.in_features ** -0.5\n            nn.init.normal_(self.attnpool.q_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.k_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.v_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.c_proj.weight, std=std)\n\n        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\n            for name, param in resnet_block.named_parameters():\n                if name.endswith(\"bn3.weight\"):\n                    nn.init.zeros_(param)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        assert unlocked_groups == 0, 'partial locking not currently supported for this model'\n        for param in self.parameters():\n            param.requires_grad = False\n        if freeze_bn_stats:\n            freeze_batch_norm_2d(self)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        # FIXME support for non-transformer\n        pass\n\n    def stem(self, x):\n        x = self.act1(self.bn1(self.conv1(x)))\n        x = self.act2(self.bn2(self.conv2(x)))\n        x = self.act3(self.bn3(self.conv3(x)))\n        x = self.avgpool(x)\n        return x\n\n    def forward(self, x, out_blocks):\n        x = self.stem(x)\n        x_1 = self.layer1(x)\n        x_2 = self.layer2(x_1)\n        x_3 = self.layer3(x_2)\n        x_4 = self.layer4(x_3)\n        x = self.attnpool(x_4)\n\n        out_tokens = []\n        x_blocks = [x_1, x_2, x_3, x_4]\n        for i in out_blocks:\n            out_tokens.append(x_blocks[i - 1])\n\n        return x, out_tokens\n"
  },
  {
    "path": "open_clip/openai.py",
    "content": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\n\nimport os\nimport warnings\nfrom typing import List, Optional, Union\n\nimport torch\n\nfrom .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype\nfrom .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url\n\n__all__ = [\"list_openai_models\", \"load_openai_model\"]\n\n\ndef list_openai_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list_pretrained_models_by_tag('openai')\n\n\ndef load_openai_model(\n        name: str,\n        precision: Optional[str] = None,\n        device: Optional[Union[str, torch.device]] = None,\n        jit: bool = True,\n        cache_dir: Optional[str] = None,\n):\n    \"\"\"Load a CLIP model\n\n    Parameters\n    ----------\n    name : str\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\n    precision: str\n        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.\n    device : Union[str, torch.device]\n        The device to put the loaded model\n    jit : bool\n        Whether to load the optimized JIT model (default) or more hackable non-JIT model.\n    cache_dir : Optional[str]\n        The directory to cache the downloaded model weights\n\n    Returns\n    -------\n    model : torch.nn.Module\n        The CLIP model\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if device is None:\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    if precision is None:\n        precision = 'fp32' if device == 'cpu' else 'fp16'\n\n    if get_pretrained_url(name, 'openai'):\n        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(f\"Model {name} not found; available models = {list_openai_models()}\")\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=device if jit else \"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state dict\n        if jit:\n            warnings.warn(f\"File {model_path} is not a JIT archive. Loading as a state dict instead\")\n            jit = False\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    if not jit:\n        # Build a non-jit model from the OpenAI jitted model state dict\n        cast_dtype = get_cast_dtype(precision)\n        try:\n            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)\n        except KeyError:\n            sd = {k[7:]: v for k, v in state_dict[\"state_dict\"].items()}\n            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)\n\n        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use\n        model = model.to(device)\n        if precision.startswith('amp') or precision == 'fp32':\n            model.float()\n        elif precision == 'bf16':\n            convert_weights_to_lp(model, dtype=torch.bfloat16)\n\n        return model\n\n    # patch the device names\n    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])\n    device_node = [n for n in device_holder.graph.findAllNodes(\"prim::Constant\") if \"Device\" in repr(n)][-1]\n\n    def patch_device(module):\n        try:\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\n        except RuntimeError:\n            graphs = []\n\n        if hasattr(module, \"forward1\"):\n            graphs.append(module.forward1.graph)\n\n        for graph in graphs:\n            for node in graph.findAllNodes(\"prim::Constant\"):\n                if \"value\" in node.attributeNames() and str(node[\"value\"]).startswith(\"cuda\"):\n                    node.copyAttributes(device_node)\n\n    model.apply(patch_device)\n    patch_device(model.encode_image)\n    patch_device(model.encode_text)\n\n    # patch dtype to float32 (typically for CPU)\n    if precision == 'fp32':\n        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\n        float_node = float_input.node()\n\n        def patch_float(module):\n            try:\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\n            except RuntimeError:\n                graphs = []\n\n            if hasattr(module, \"forward1\"):\n                graphs.append(module.forward1.graph)\n\n            for graph in graphs:\n                for node in graph.findAllNodes(\"aten::to\"):\n                    inputs = list(node.inputs())\n                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()\n                        if inputs[i].node()[\"value\"] == 5:\n                            inputs[i].node().copyAttributes(float_node)\n\n        model.apply(patch_float)\n        patch_float(model.encode_image)\n        patch_float(model.encode_text)\n        model.float()\n\n    # ensure image_size attr available at consistent location for both jit and non-jit\n    model.visual.image_size = model.input_resolution.item()\n    return model\n"
  },
  {
    "path": "open_clip/pretrained.py",
    "content": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nfrom tqdm import tqdm\n\nfrom .version import __version__\n\ntry:\n    from huggingface_hub import hf_hub_download\n    hf_hub_download = partial(hf_hub_download, library_name=\"open_clip\", library_version=__version__)\n    _has_hf_hub = True\nexcept ImportError:\n    hf_hub_download = None\n    _has_hf_hub = False\n\n\ndef _pcfg(url='', hf_hub='', mean=None, std=None):\n    return dict(\n        url=url,\n        hf_hub=hf_hub,\n        mean=mean,\n        std=std,\n    )\n\n\n_RN50 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\"),\n    cc12m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\"),\n)\n\n_RN50_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\"),\n    cc12m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\"),\n)\n\n_RN101 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\"),\n)\n\n_RN101_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\"),\n)\n\n_RN50x4 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\"),\n)\n\n_RN50x16 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\"),\n)\n\n_RN50x64 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt\"),\n)\n\n_VITB32 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\n    laion2b_e16=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth\"),\n    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')\n)\n\n_VITB32_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\n)\n\n_VITB16 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt\"),\n    # laion400m_32k=_pcfg(\n    #     url=\"\",\n    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n    # laion400m_64k=_pcfg(\n    #     url=\"\",\n    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),\n)\n\n_VITB16_PLUS_240 = dict(\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt\"),\n)\n\n_VITL14 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt\"),\n    laion2b_s32b_b82k=_pcfg(\n        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n)\n\n_VITL14_336 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\"),\n)\n\n_VITH14 = dict(\n    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),\n)\n\n_VITg14 = dict(\n    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),\n)\n\n_VITbigG14 = dict(\n    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),\n)\n\n_robertaViTB32 = dict(\n    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),\n)\n\n_xlmRobertaBaseViTB32 = dict(\n    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),\n)\n\n_xlmRobertaLargeFrozenViTH14 = dict(\n    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),\n)\n\n_convnext_base = dict(\n    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),\n)\n\n_convnext_base_w = dict(\n    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),\n    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),\n    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),\n)\n\n_convnext_base_w_320 = dict(\n    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),\n    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),\n)\n\n_convnext_large_d = dict(\n    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),\n)\n\n_convnext_large_d_320 = dict(\n    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),\n    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),\n)\n\n_convnext_xxlarge = dict(\n    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),\n    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),\n    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),\n)\n\n_coca_VITB32 = dict(\n    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),\n    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')\n)\n\n_coca_VITL14 = dict(\n    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),\n    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')\n)\n\n\n_PRETRAINED = {\n    \"RN50\": _RN50,\n    \"RN50-quickgelu\": _RN50_quickgelu,\n    \"RN101\": _RN101,\n    \"RN101-quickgelu\": _RN101_quickgelu,\n    \"RN50x4\": _RN50x4,\n    \"RN50x16\": _RN50x16,\n    \"RN50x64\": _RN50x64,\n    \"ViT-B-32\": _VITB32,\n    \"ViT-B-32-quickgelu\": _VITB32_quickgelu,\n    \"ViT-B-16\": _VITB16,\n    \"ViT-B-16-plus-240\": _VITB16_PLUS_240,\n    \"ViT-L-14\": _VITL14,\n    \"ViT-L-14-336\": _VITL14_336,\n    \"ViT-H-14\": _VITH14,\n    \"ViT-g-14\": _VITg14,\n    \"ViT-bigG-14\": _VITbigG14,\n    \"roberta-ViT-B-32\": _robertaViTB32,\n    \"xlm-roberta-base-ViT-B-32\": _xlmRobertaBaseViTB32,\n    \"xlm-roberta-large-ViT-H-14\": _xlmRobertaLargeFrozenViTH14,\n    \"convnext_base\": _convnext_base,\n    \"convnext_base_w\": _convnext_base_w,\n    \"convnext_base_w_320\": _convnext_base_w_320,\n    \"convnext_large_d\": _convnext_large_d,\n    \"convnext_large_d_320\": _convnext_large_d_320,\n    \"convnext_xxlarge\": _convnext_xxlarge,\n    \"coca_ViT-B-32\": _coca_VITB32,\n    \"coca_ViT-L-14\": _coca_VITL14,\n}\n\n\ndef _clean_tag(tag: str):\n    # normalize pretrained tags\n    return tag.lower().replace('-', '_')\n\n\ndef list_pretrained(as_str: bool = False):\n    \"\"\" returns list of pretrained models\n    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\n    \"\"\"\n    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]\n\n\ndef list_pretrained_models_by_tag(tag: str):\n    \"\"\" return all models having the specified pretrain tag \"\"\"\n    models = []\n    tag = _clean_tag(tag)\n    for k in _PRETRAINED.keys():\n        if tag in _PRETRAINED[k]:\n            models.append(k)\n    return models\n\n\ndef list_pretrained_tags_by_model(model: str):\n    \"\"\" return all pretrain tags for the specified model architecture \"\"\"\n    tags = []\n    if model in _PRETRAINED:\n        tags.extend(_PRETRAINED[model].keys())\n    return tags\n\n\ndef is_pretrained_cfg(model: str, tag: str):\n    if model not in _PRETRAINED:\n        return False\n    return _clean_tag(tag) in _PRETRAINED[model]\n\n\ndef get_pretrained_cfg(model: str, tag: str):\n    if model not in _PRETRAINED:\n        return {}\n    model_pretrained = _PRETRAINED[model]\n    return model_pretrained.get(_clean_tag(tag), {})\n\n\ndef get_pretrained_url(model: str, tag: str):\n    cfg = get_pretrained_cfg(model, _clean_tag(tag))\n    return cfg.get('url', '')\n\n\ndef download_pretrained_from_url(\n        url: str,\n        cache_dir: Union[str, None] = None,\n):\n    if not cache_dir:\n        cache_dir = os.path.expanduser(\"~/.cache/clip\")\n    os.makedirs(cache_dir, exist_ok=True)\n    filename = os.path.basename(url)\n\n    if 'openaipublic' in url:\n        expected_sha256 = url.split(\"/\")[-2]\n    elif 'mlfoundations' in url:\n        expected_sha256 = os.path.splitext(filename)[0].split(\"-\")[-1]\n    else:\n        expected_sha256 = ''\n\n    download_target = os.path.join(cache_dir, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if expected_sha256:\n            if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n                return download_target\n            else:\n                warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n        else:\n            return download_target\n\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n        with tqdm(total=int(source.headers.get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if expected_sha256 and not hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n        raise RuntimeError(f\"Model has been downloaded but the SHA256 checksum does not not match\")\n\n    return download_target\n\n\ndef has_hf_hub(necessary=False):\n    if not _has_hf_hub and necessary:\n        # if no HF Hub module installed, and it is necessary to continue, raise error\n        raise RuntimeError(\n            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')\n    return _has_hf_hub\n\n\ndef download_pretrained_from_hf(\n        model_id: str,\n        filename: str = 'open_clip_pytorch_model.bin',\n        revision=None,\n        cache_dir: Union[str, None] = None,\n):\n    has_hf_hub(True)\n    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)\n    return cached_file\n\n\ndef download_pretrained(\n        cfg: Dict,\n        force_hf_hub: bool = False,\n        cache_dir: Union[str, None] = None,\n):\n    target = ''\n    if not cfg:\n        return target\n\n    download_url = cfg.get('url', '')\n    download_hf_hub = cfg.get('hf_hub', '')\n    if download_hf_hub and force_hf_hub:\n        # use HF hub even if url exists\n        download_url = ''\n\n    if download_url:\n        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)\n    elif download_hf_hub:\n        has_hf_hub(True)\n        # we assume the hf_hub entries in pretrained config combine model_id + filename in\n        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and\n        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.\n        model_id, filename = os.path.split(download_hf_hub)\n        if filename:\n            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)\n        else:\n            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n\n    return target\n"
  },
  {
    "path": "open_clip/push_to_hf_hub.py",
    "content": "import argparse\nimport json\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import Optional, Tuple\n\nimport torch\n\ntry:\n    from huggingface_hub import (\n        create_repo,\n        get_hf_file_metadata,\n        hf_hub_download,\n        hf_hub_url,\n        repo_type_and_id_from_hf_id,\n        upload_folder,\n    )\n    from huggingface_hub.utils import EntryNotFoundError\n    _has_hf_hub = True\nexcept ImportError:\n    _has_hf_hub = False\n\nfrom .factory import create_model_from_pretrained, get_model_config, get_tokenizer\nfrom .tokenizer import HFTokenizer\n\n\ndef save_config_for_hf(\n        model,\n        config_path: str,\n        model_config: Optional[dict]\n):\n    preprocess_cfg = {\n        'mean': model.visual.image_mean,\n        'std': model.visual.image_std,\n    }\n    hf_config = {\n        'model_cfg': model_config,\n        'preprocess_cfg': preprocess_cfg,\n    }\n\n    with config_path.open('w') as f:\n        json.dump(hf_config, f, indent=2)\n\n\ndef save_for_hf(\n    model,\n    tokenizer: HFTokenizer,\n    model_config: dict,\n    save_directory: str,\n    weights_filename='open_clip_pytorch_model.bin',\n    config_filename='open_clip_config.json',\n):\n    save_directory = Path(save_directory)\n    save_directory.mkdir(exist_ok=True, parents=True)\n\n    weights_path = save_directory / weights_filename\n    torch.save(model.state_dict(), weights_path)\n\n    tokenizer.save_pretrained(save_directory)\n\n    config_path = save_directory / config_filename\n    save_config_for_hf(model, config_path, model_config=model_config)\n\n\ndef push_to_hf_hub(\n    model,\n    tokenizer,\n    model_config: Optional[dict],\n    repo_id: str,\n    commit_message: str = 'Add model',\n    token: Optional[str] = None,\n    revision: Optional[str] = None,\n    private: bool = False,\n    create_pr: bool = False,\n    model_card: Optional[dict] = None,\n):\n    if not isinstance(tokenizer, HFTokenizer):\n        # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14\n        tokenizer = HFTokenizer('openai/clip-vit-large-patch14')\n\n    # Create repo if it doesn't exist yet\n    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)\n\n    # Infer complete repo_id from repo_url\n    # Can be different from the input `repo_id` if repo_owner was implicit\n    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)\n    repo_id = f\"{repo_owner}/{repo_name}\"\n\n    # Check if README file already exist in repo\n    try:\n        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename=\"README.md\", revision=revision))\n        has_readme = True\n    except EntryNotFoundError:\n        has_readme = False\n\n    # Dump model and push to Hub\n    with TemporaryDirectory() as tmpdir:\n        # Save model weights and config.\n        save_for_hf(\n            model,\n            tokenizer=tokenizer,\n            model_config=model_config,\n            save_directory=tmpdir,\n        )\n\n        # Add readme if it does not exist\n        if not has_readme:\n            model_card = model_card or {}\n            model_name = repo_id.split('/')[-1]\n            readme_path = Path(tmpdir) / \"README.md\"\n            readme_text = generate_readme(model_card, model_name)\n            readme_path.write_text(readme_text)\n\n        # Upload model and return\n        return upload_folder(\n            repo_id=repo_id,\n            folder_path=tmpdir,\n            revision=revision,\n            create_pr=create_pr,\n            commit_message=commit_message,\n        )\n\n\ndef push_pretrained_to_hf_hub(\n    model_name,\n    pretrained: str,\n    repo_id: str,\n    image_mean: Optional[Tuple[float, ...]] = None,\n    image_std: Optional[Tuple[float, ...]] = None,\n    commit_message: str = 'Add model',\n    token: Optional[str] = None,\n    revision: Optional[str] = None,\n    private: bool = False,\n    create_pr: bool = False,\n    model_card: Optional[dict] = None,\n):\n    model, preprocess_eval = create_model_from_pretrained(\n        model_name,\n        pretrained=pretrained,\n        image_mean=image_mean,\n        image_std=image_std,\n    )\n\n    model_config = get_model_config(model_name)\n    assert model_config\n\n    tokenizer = get_tokenizer(model_name)\n\n    push_to_hf_hub(\n        model=model,\n        tokenizer=tokenizer,\n        model_config=model_config,\n        repo_id=repo_id,\n        commit_message=commit_message,\n        token=token,\n        revision=revision,\n        private=private,\n        create_pr=create_pr,\n        model_card=model_card,\n    )\n\n\ndef generate_readme(model_card: dict, model_name: str):\n    readme_text = \"---\\n\"\n    readme_text += \"tags:\\n- zero-shot-image-classification\\n- clip\\n\"\n    readme_text += \"library_tag: open_clip\\n\"\n    readme_text += f\"license: {model_card.get('license', 'mit')}\\n\"\n    if 'details' in model_card and 'Dataset' in model_card['details']:\n        readme_text += 'datasets:\\n'\n        readme_text += f\"- {model_card['details']['Dataset'].lower()}\\n\"\n    readme_text += \"---\\n\"\n    readme_text += f\"# Model card for {model_name}\\n\"\n    if 'description' in model_card:\n        readme_text += f\"\\n{model_card['description']}\\n\"\n    if 'details' in model_card:\n        readme_text += f\"\\n## Model Details\\n\"\n        for k, v in model_card['details'].items():\n            if isinstance(v, (list, tuple)):\n                readme_text += f\"- **{k}:**\\n\"\n                for vi in v:\n                    readme_text += f\"  - {vi}\\n\"\n            elif isinstance(v, dict):\n                readme_text += f\"- **{k}:**\\n\"\n                for ki, vi in v.items():\n                    readme_text += f\"  - {ki}: {vi}\\n\"\n            else:\n                readme_text += f\"- **{k}:** {v}\\n\"\n    if 'usage' in model_card:\n        readme_text += f\"\\n## Model Usage\\n\"\n        readme_text += model_card['usage']\n        readme_text += '\\n'\n\n    if 'comparison' in model_card:\n        readme_text += f\"\\n## Model Comparison\\n\"\n        readme_text += model_card['comparison']\n        readme_text += '\\n'\n\n    if 'citation' in model_card:\n        readme_text += f\"\\n## Citation\\n\"\n        if not isinstance(model_card['citation'], (list, tuple)):\n            citations = [model_card['citation']]\n        else:\n            citations = model_card['citation']\n        for c in citations:\n            readme_text += f\"```bibtex\\n{c}\\n```\\n\"\n\n    return readme_text\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Push to Hugging Face Hub\")\n    parser.add_argument(\n        \"--model\", type=str, help=\"Name of the model to use.\",\n    )\n    parser.add_argument(\n        \"--pretrained\", type=str,\n        help=\"Use a pretrained CLIP model weights with the specified tag or file path.\",\n    )\n    parser.add_argument(\n        \"--repo-id\", type=str,\n        help=\"Destination HF Hub repo-id ie 'organization/model_id'.\",\n    )\n    parser.add_argument(\n        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',\n        help='Override default image mean value of dataset')\n    parser.add_argument(\n        '--image-std', type=float, nargs='+', default=None, metavar='STD',\n        help='Override default image std deviation of of dataset')\n    args = parser.parse_args()\n\n    print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')\n\n    # FIXME add support to pass model_card json / template from file via cmd line\n\n    push_pretrained_to_hf_hub(\n        args.model,\n        args.pretrained,\n        args.repo_id,\n        image_mean=args.image_mean,  # override image mean/std if trained w/ non defaults\n        image_std=args.image_std,\n    )\n\n    print(f'{args.model} saved.')\n"
  },
  {
    "path": "open_clip/timm_model.py",
    "content": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.\n\"\"\"\nimport logging\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\n\ntry:\n    import timm\n    from timm.models.layers import Mlp, to_2tuple\n    try:\n        # old timm imports < 0.8.1\n        from timm.models.layers.attention_pool2d import RotAttentionPool2d\n        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d\n    except ImportError:\n        # new timm imports >= 0.8.1\n        from timm.layers import RotAttentionPool2d\n        from timm.layers import AttentionPool2d as AbsAttentionPool2d\nexcept ImportError:\n    timm = None\n\nfrom .utils import freeze_batch_norm_2d\n\n\nclass TimmModel(nn.Module):\n    \"\"\" timm model adapter\n    # FIXME this adapter is a work in progress, may change in ways that break weight compat\n    \"\"\"\n\n    def __init__(\n            self,\n            model_name,\n            embed_dim,\n            image_size=224,\n            pool='avg',\n            proj='linear',\n            proj_bias=False,\n            drop=0.,\n            drop_path=None,\n            pretrained=False,\n    ):\n        super().__init__()\n        if timm is None:\n            raise RuntimeError(\"Please `pip install timm` to use timm models.\")\n\n        self.image_size = to_2tuple(image_size)\n        timm_kwargs = {}\n        if drop_path is not None:\n            timm_kwargs['drop_path_rate'] = drop_path\n        self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)\n        feat_size = self.trunk.default_cfg.get('pool_size', None)\n        feature_ndim = 1 if not feat_size else 2\n        if pool in ('abs_attn', 'rot_attn'):\n            assert feature_ndim == 2\n            # if attn pooling used, remove both classifier and default pool\n            self.trunk.reset_classifier(0, global_pool='')\n        else:\n            # reset global pool if pool config set, otherwise leave as network default\n            reset_kwargs = dict(global_pool=pool) if pool else {}\n            self.trunk.reset_classifier(0, **reset_kwargs)\n        prev_chs = self.trunk.num_features\n\n        head_layers = OrderedDict()\n        if pool == 'abs_attn':\n            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)\n            prev_chs = embed_dim\n        elif pool == 'rot_attn':\n            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)\n            prev_chs = embed_dim\n        else:\n            assert proj, 'projection layer needed if non-attention pooling is used.'\n\n        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used\n        if proj == 'linear':\n            head_layers['drop'] = nn.Dropout(drop)\n            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)\n        elif proj == 'mlp':\n            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))\n\n        self.head = nn.Sequential(head_layers)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        \"\"\" lock modules\n        Args:\n            unlocked_groups (int): leave last n layer groups unlocked (default: 0)\n        \"\"\"\n        if not unlocked_groups:\n            # lock full model\n            for param in self.trunk.parameters():\n                param.requires_grad = False\n            if freeze_bn_stats:\n                freeze_batch_norm_2d(self.trunk)\n        else:\n            # NOTE: partial freeze requires latest timm (master) branch and is subject to change\n            try:\n                # FIXME import here until API stable and in an official release\n                from timm.models.helpers import group_parameters, group_modules\n            except ImportError:\n                raise RuntimeError(\n                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')\n            matcher = self.trunk.group_matcher()\n            gparams = group_parameters(self.trunk, matcher)\n            max_layer_id = max(gparams.keys())\n            max_layer_id = max_layer_id - unlocked_groups\n            for group_idx in range(max_layer_id + 1):\n                group = gparams[group_idx]\n                for param in group:\n                    self.trunk.get_parameter(param).requires_grad = False\n            if freeze_bn_stats:\n                gmodules = group_modules(self.trunk, matcher, reverse=True)\n                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}\n                freeze_batch_norm_2d(self.trunk, gmodules)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        try:\n            self.trunk.set_grad_checkpointing(enable)\n        except Exception as e:\n            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')\n\n    def forward(self, x):\n        x = self.trunk(x)\n        x = self.head(x)\n        return x\n"
  },
  {
    "path": "open_clip/tokenizer.py",
    "content": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimport gzip\nimport html\nimport os\nfrom functools import lru_cache\nfrom typing import Union, List\n\nimport ftfy\nimport regex as re\nimport torch\n\n# https://stackoverflow.com/q/62691279\nimport os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\")\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a significant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\nclass SimpleTokenizer(object):\n    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\n        merges = merges[1:49152-256-2+1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v+'</w>' for v in vocab]\n        for merge in merges:\n            vocab.append(''.join(merge))\n        if not special_tokens:\n            special_tokens = ['<start_of_text>', '<end_of_text>']\n        else:\n            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens\n        vocab.extend(special_tokens)\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {t:t for t in special_tokens}\n        special = \"|\".join(special_tokens)\n        self.pat = re.compile(special + r\"\"\"|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n\n        self.vocab_size = len(self.encoder)\n        self.all_special_ids = [self.encoder[t] for t in special_tokens]\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + ( token[-1] + '</w>',)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token+'</w>'\n\n        while True:\n            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except:\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word)-1 and word[i+1] == second:\n                    new_word.append(first+second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = ' '.join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = whitespace_clean(basic_clean(text)).lower()\n        for token in re.findall(self.pat, text):\n            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = ''.join([self.decoder[token] for token in tokens])\n        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n        return text\n\n\n_tokenizer = SimpleTokenizer()\n\ndef decode(output_ids: torch.Tensor):\n    output_ids = output_ids.cpu().numpy()\n    return _tokenizer.decode(output_ids)\n\ndef tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:\n    \"\"\"\n    Returns the tokenized representation of given input string(s)\n\n    Parameters\n    ----------\n    texts : Union[str, List[str]]\n        An input string or a list of input strings to tokenize\n    context_length : int\n        The context length to use; all CLIP models use 77 as the context length\n\n    Returns\n    -------\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n    \"\"\"\n    if isinstance(texts, str):\n        texts = [texts]\n\n    sot_token = _tokenizer.encoder[\"<start_of_text>\"]\n    eot_token = _tokenizer.encoder[\"<end_of_text>\"]\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        if len(tokens) > context_length:\n            tokens = tokens[:context_length]  # Truncate\n            tokens[-1] = eot_token\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result\n\n\nclass HFTokenizer:\n    \"\"\"HuggingFace tokenizer wrapper\"\"\"\n\n    def __init__(self, tokenizer_name: str):\n        from transformers import AutoTokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n\n    def save_pretrained(self, dest):\n        self.tokenizer.save_pretrained(dest)\n\n    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:\n        # same cleaning as for default tokenizer, except lowercasing\n        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance\n        if isinstance(texts, str):\n            texts = [texts]\n        texts = [whitespace_clean(basic_clean(text)) for text in texts]\n        input_ids = self.tokenizer(\n            texts,\n            return_tensors='pt',\n            max_length=context_length,\n            padding='max_length',\n            truncation=True,\n        ).input_ids\n        return input_ids\n"
  },
  {
    "path": "open_clip/transform.py",
    "content": "import warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms.functional as F\n\nfrom torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n    CenterCrop\n\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\n\n\n@dataclass\nclass AugmentationCfg:\n    scale: Tuple[float, float] = (0.9, 1.0)\n    ratio: Optional[Tuple[float, float]] = None\n    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None\n    interpolation: Optional[str] = None\n    re_prob: Optional[float] = None\n    re_count: Optional[int] = None\n    use_timm: bool = False\n\n\nclass ResizeMaxSize(nn.Module):\n\n    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):\n        super().__init__()\n        if not isinstance(max_size, int):\n            raise TypeError(f\"Size should be int. Got {type(max_size)}\")\n        self.max_size = max_size\n        self.interpolation = interpolation\n        self.fn = min if fn == 'min' else min\n        self.fill = fill\n\n    def forward(self, img):\n        if isinstance(img, torch.Tensor):\n            height, width = img.shape[:2]\n        else:\n            width, height = img.size\n        scale = self.max_size / float(max(height, width))\n        if scale != 1.0:\n            new_size = tuple(round(dim * scale) for dim in (height, width))\n            img = F.resize(img, new_size, self.interpolation)\n            pad_h = self.max_size - new_size[0]\n            pad_w = self.max_size - new_size[1]\n            img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)\n        return img\n\n\ndef _convert_to_rgb(image):\n    return image.convert('RGB')\n\n\ndef image_transform(\n        image_size: int,\n        is_train: bool,\n        mean: Optional[Tuple[float, ...]] = None,\n        std: Optional[Tuple[float, ...]] = None,\n        resize_longest_max: bool = False,\n        fill_color: int = 0,\n        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n):\n    mean = mean or OPENAI_DATASET_MEAN\n    if not isinstance(mean, (list, tuple)):\n        mean = (mean,) * 3\n\n    std = std or OPENAI_DATASET_STD\n    if not isinstance(std, (list, tuple)):\n        std = (std,) * 3\n\n    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:\n        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge\n        image_size = image_size[0]\n\n    if isinstance(aug_cfg, dict):\n        aug_cfg = AugmentationCfg(**aug_cfg)\n    else:\n        aug_cfg = aug_cfg or AugmentationCfg()\n    normalize = Normalize(mean=mean, std=std)\n    if is_train:\n        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}\n        use_timm = aug_cfg_dict.pop('use_timm', False)\n        if use_timm:\n            from timm.data import create_transform  # timm can still be optional\n            if isinstance(image_size, (tuple, list)):\n                assert len(image_size) >= 2\n                input_size = (3,) + image_size[-2:]\n            else:\n                input_size = (3, image_size, image_size)\n            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time\n            aug_cfg_dict.setdefault('interpolation', 'random')\n            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default\n            train_transform = create_transform(\n                input_size=input_size,\n                is_training=True,\n                hflip=0.,\n                mean=mean,\n                std=std,\n                re_mode='pixel',\n                **aug_cfg_dict,\n            )\n        else:\n            train_transform = Compose([\n                RandomResizedCrop(\n                    image_size,\n                    scale=aug_cfg_dict.pop('scale'),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                _convert_to_rgb,\n                ToTensor(),\n                normalize,\n            ])\n            if aug_cfg_dict:\n                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')\n        return train_transform\n    else:\n        if resize_longest_max:\n            transforms = [\n                ResizeMaxSize(image_size, fill=fill_color)\n            ]\n        else:\n            transforms = [\n                Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),\n                CenterCrop((image_size, image_size)),\n            ]\n        transforms.extend([\n            _convert_to_rgb,\n            ToTensor(),\n            normalize,\n        ])\n        return Compose(transforms)\n"
  },
  {
    "path": "open_clip/transformer.py",
    "content": "from collections import OrderedDict\nimport math\nfrom typing import Callable, Optional, Sequence, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils.checkpoint import checkpoint\n\nfrom .utils import to_2tuple\nimport numpy as np\n\n\nclass LayerNormFp32(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)\n        return x.to(orig_type)\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm (with cast back to input dtype).\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        return x.to(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_values=1e-5, inplace=False):\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x):\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n\n\nclass PatchDropout(nn.Module):\n    \"\"\"\n    https://arxiv.org/abs/2212.00794\n    \"\"\"\n\n    def __init__(self, prob, exclude_first_token=True):\n        super().__init__()\n        assert 0 <= prob < 1.\n        self.prob = prob\n        self.exclude_first_token = exclude_first_token  # exclude CLS token\n\n    def forward(self, x):\n        if not self.training or self.prob == 0.:\n            return x\n\n        if self.exclude_first_token:\n            cls_tokens, x = x[:, :1], x[:, 1:]\n        else:\n            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])\n\n        batch = x.size()[0]\n        num_tokens = x.size()[1]\n\n        batch_indices = torch.arange(batch)\n        batch_indices = batch_indices[..., None]\n\n        keep_prob = 1 - self.prob\n        num_patches_keep = max(1, int(num_tokens * keep_prob))\n\n        rand = torch.randn(batch, num_tokens)\n        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices\n\n        x = x[batch_indices, patch_indices_keep]\n\n        if self.exclude_first_token:\n            x = torch.cat((cls_tokens, x), dim=1)\n\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            num_heads=8,\n            qkv_bias=True,\n            scaled_cosine=False,\n            scale_heads=False,\n            logit_scale_max=math.log(1. / 0.01),\n            attn_drop=0.,\n            proj_drop=0.\n    ):\n        super().__init__()\n        self.scaled_cosine = scaled_cosine\n        self.scale_heads = scale_heads\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n        self.logit_scale_max = logit_scale_max\n\n        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\n        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\n        if qkv_bias:\n            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\n        else:\n            self.in_proj_bias = None\n\n        if self.scaled_cosine:\n            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\n        else:\n            self.logit_scale = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        if self.scale_heads:\n            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\n        else:\n            self.head_scale = None\n        self.out_proj = nn.Linear(dim, dim)\n        self.out_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):\n        L, N, C = x.shape\n        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)\n        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\n        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\n        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\n\n        if self.logit_scale is not None:\n            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\n            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\n            attn = attn.view(N, self.num_heads, L, L) * logit_scale\n            attn = attn.view(-1, L, L)\n        else:\n            q = q * self.scale\n            attn = torch.bmm(q, k.transpose(-1, -2))\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\n                new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n                attn_mask = new_attn_mask\n            attn += attn_mask\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = torch.bmm(attn, v)\n        if self.head_scale is not None:\n            x = x.view(N, self.num_heads, L, C) * self.head_scale\n            x = x.view(-1, L, C)\n        x = x.transpose(0, 1).reshape(L, N, C)\n        x = self.out_proj(x)\n        x = self.out_drop(x)\n        return x\n\n\nclass AttentionalPooler(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            context_dim: int,\n            n_head: int = 8,\n            n_queries: int = 256,\n            norm_layer: Callable = LayerNorm\n    ):\n        super().__init__()\n        self.query = nn.Parameter(torch.randn(n_queries, d_model))\n        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)\n        self.ln_q = norm_layer(d_model)\n        self.ln_k = norm_layer(context_dim)\n\n    def forward(self, x: torch.Tensor):\n        x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND\n        N = x.shape[1]\n        q = self.ln_q(self.query)\n        out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]\n        return out.permute(1, 0, 2)  # LND -> NLD\n\n    def _repeat(self, query, N: int):\n        return query.unsqueeze(1).repeat(1, N, 1)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            n_head: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            is_cross_attention: bool = False,\n            idx: int = 12,\n    ):\n        super().__init__()\n\n        self.idx = idx\n\n        self.ln_1 = norm_layer(d_model)\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n        if is_cross_attention:\n            self.ln_1_kv = norm_layer(d_model)\n\n        self.ln_2 = norm_layer(d_model)\n        mlp_width = int(d_model * mlp_ratio)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\n            (\"gelu\", act_layer()),\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\n        ]))\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n    def attention(\n            self,\n            q_x: torch.Tensor,\n            k_x: Optional[torch.Tensor] = None,\n            v_x: Optional[torch.Tensor] = None,\n            attn_mask: Optional[torch.Tensor] = None,\n    ):\n        k_x = k_x if k_x is not None else q_x\n        v_x = v_x if v_x is not None else q_x\n\n        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None\n        return self.attn(\n            q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask\n        )\n\n    def forward(\n            self,\n            q_x: torch.Tensor,\n            k_x: Optional[torch.Tensor] = None,\n            v_x: Optional[torch.Tensor] = None,\n            attn_mask: Optional[torch.Tensor] = None,\n    ):\n        k_x = self.ln_1_kv(k_x) if hasattr(self, \"ln_1_kv\") and k_x is not None else None\n        v_x = self.ln_1_kv(v_x) if hasattr(self, \"ln_1_kv\") and v_x is not None else None\n\n        tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)\n        x = q_x + self.ls_1(tmp)\n        x = x + self.ls_2(self.mlp(self.ln_2(x)))\n        return x, attn\n\n\nclass CustomResidualAttentionBlock(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            n_head: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            scale_cosine_attn: bool = False,\n            scale_heads: bool = False,\n            scale_attn: bool = False,\n            scale_fc: bool = False,\n    ):\n        super().__init__()\n\n        self.ln_1 = norm_layer(d_model)\n        self.attn = Attention(\n            d_model, n_head,\n            scaled_cosine=scale_cosine_attn,\n            scale_heads=scale_heads,\n        )\n        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n        self.ln_2 = norm_layer(d_model)\n        mlp_width = int(d_model * mlp_ratio)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\n            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),\n            (\"gelu\", act_layer()),\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\n        ]))\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))\n        x = x + self.ls_2(self.mlp(self.ln_2(x)))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(\n            self,\n            width: int,\n            layers: int,\n            heads: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n    ):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.grad_checkpointing = False\n\n        self.resblocks = nn.ModuleList([\n            ResidualAttentionBlock(\n                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer,\n                idx=idx)\n            for idx in range(layers)\n        ])\n\n    def get_cast_dtype(self) -> torch.dtype:\n        return self.resblocks[0].mlp.c_fc.weight.dtype\n\n    def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],\n                attn_mask: Optional[torch.Tensor] = None):\n        idx = 0\n        out_attn = []\n        # out_tokens = x\n        out_tokens = []\n        for r in self.resblocks:\n            idx += 1\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n                x = checkpoint(r, x, None, None, attn_mask)\n            else:\n                if idx == 12:\n                    x, attn = r(x, attn_mask=attn_mask)\n                    out_attn.append(attn)\n                else:\n                    x, attn_tmp = r(x, attn_mask=attn_mask)\n                if idx in out_layers:\n                    out_tokens.append(x)\n                    # out_tokens = x\n        return x, out_attn, out_tokens\n\n\nclass VisionTransformer(nn.Module):\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            image_size: int,\n            patch_size: int,\n            width: int,\n            layers: int,\n            heads: int,\n            mlp_ratio: float,\n            ls_init_value: float = None,\n            global_average_pool: bool = False,\n            attentional_pool: bool = False,\n            n_queries: int = 256,\n            attn_pooler_heads: int = 8,\n            output_dim: int = 512,\n            patch_dropout: float = 0.,\n            input_patchnorm: bool = False,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            output_tokens: bool = False\n    ):\n        super().__init__()\n        self.output_tokens = output_tokens\n        image_height, image_width = self.image_size = to_2tuple(image_size)\n        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)\n        self.grid_size = (image_height // patch_height, image_width // patch_width)\n        self.output_dim = output_dim\n\n        # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1\n        self.input_patchnorm = input_patchnorm\n\n        if input_patchnorm:\n            patch_input_dim = patch_height * patch_width * 3\n            self.patchnorm_pre_ln = LayerNorm(patch_input_dim)\n            self.conv1 = nn.Linear(patch_input_dim, width)\n        else:\n            self.patchnorm_pre_ln = nn.Identity()\n            self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,\n                                   bias=False)\n\n        # class embeddings and positional embeddings\n        scale = width ** -0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))\n\n        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\n        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()\n\n        self.ln_pre = norm_layer(width)\n        self.transformer = Transformer(\n            width,\n            layers,\n            heads,\n            mlp_ratio,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n\n        self.global_average_pool = global_average_pool\n        if attentional_pool:\n            self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)\n            self.ln_post = norm_layer(output_dim)\n            self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))\n        else:\n            self.attn_pool = None\n            self.ln_post = norm_layer(width)\n            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\n\n        self.init_parameters()\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        for param in self.parameters():\n            param.requires_grad = False\n\n        if unlocked_groups != 0:\n            groups = [\n                [\n                    self.conv1,\n                    self.class_embedding,\n                    self.positional_embedding,\n                    self.ln_pre,\n                ],\n                *self.transformer.resblocks[:-1],\n                [\n                    self.transformer.resblocks[-1],\n                    self.ln_post,\n                ],\n                self.proj,\n            ]\n\n            def _unlock(x):\n                if isinstance(x, Sequence):\n                    for g in x:\n                        _unlock(g)\n                else:\n                    if isinstance(x, torch.nn.Parameter):\n                        x.requires_grad = True\n                    else:\n                        for p in x.parameters():\n                            p.requires_grad = True\n\n            _unlock(groups[-unlocked_groups:])\n\n    def init_parameters(self):\n        # FIXME OpenAI CLIP did not define an init for the VisualTransformer\n        # TODO experiment if default PyTorch init, below, or alternate init is best.\n\n        # nn.init.normal_(self.class_embedding, std=self.scale)\n        # nn.init.normal_(self.positional_embedding, std=self.scale)\n        #\n        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        # attn_std = self.transformer.width ** -0.5\n        # fc_std = (2 * self.transformer.width) ** -0.5\n        # for block in self.transformer.resblocks:\n        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n        #\n        # if self.text_projection is not None:\n        #     nn.init.normal_(self.text_projection, std=self.scale)\n        pass\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.grad_checkpointing = enable\n\n    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        if self.global_average_pool:\n            return x.mean(dim=1), x\n        else:\n            return x[:, 0], x[:, 1:]\n\n    def forward(self, x: torch.Tensor, out_layers: list):\n\n        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1\n        if self.input_patchnorm:\n            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')\n            x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],\n                          self.patch_size[1])\n            x = x.permute(0, 2, 4, 1, 3, 5)\n            x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)\n            x = self.patchnorm_pre_ln(x)\n            x = self.conv1(x)\n        else:\n            x = self.conv1(x)  # shape = [*, width, grid, grid]\n            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n\n        # class embeddings and positional embeddings\n        x = torch.cat(\n            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),\n             x], dim=1)  # shape = [*, grid ** 2 + 1, width]\n        x = x + self.positional_embedding.to(x.dtype)\n\n        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in\n        x = self.patch_dropout(x)\n        x = self.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x, attn, patch_tokens = self.transformer(x, out_layers)\n        # attn = attn[0, 0, 1:].view(14, 14)  # 49\n        B, C, L = attn[0].shape\n        H = int(np.sqrt(L-1))\n        out_attn = torch.zeros([H, H]).to('cuda')\n        for i in range(len(attn)):\n            out_attn += attn[i][0, 0, 1:].view(H, H)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]  # LND -> NLD\n        # patch_tokens = patch_tokens.permute(1, 0, 2)  # LND -> NLD\n\n        if self.attn_pool is not None:\n            x = self.attn_pool(x)\n            x = self.ln_post(x)\n            pooled, tokens = self._global_pool(x)\n        else:\n            pooled, tokens = self._global_pool(x)\n            pooled = self.ln_post(pooled)\n            # patch_pooled, patch_tokens = self._global_pool(patch_tokens)\n            # tokens = self.ln_post(tokens)\n\n        if self.proj is not None:\n            pooled = pooled @ self.proj\n            # patch_tokens = patch_tokens @ self.proj  # 不知道能不能行\n            # tokens = tokens @ self.proj\n\n        if self.output_tokens:\n            return pooled, patch_tokens\n\n        return pooled, patch_tokens\n\n\nclass TextTransformer(nn.Module):\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            context_length: int = 77,\n            vocab_size: int = 49408,\n            width: int = 512,\n            heads: int = 8,\n            layers: int = 12,\n            ls_init_value: float = None,\n            output_dim: int = 512,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            embed_cls: bool = False,\n            pad_id: int = 0,\n            output_tokens: bool = False,\n    ):\n        super().__init__()\n        self.output_tokens = output_tokens\n        self.num_pos = self.context_length = context_length\n        self.vocab_size = vocab_size\n        self.width = width\n        self.output_dim = output_dim\n        self.heads = heads\n        self.pad_id = pad_id\n\n        self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n\n        if embed_cls:\n            self.cls_emb = nn.Parameter(torch.empty(width))\n            self.num_pos += 1\n        else:\n            self.cls_emb = None\n\n        self.token_embedding = nn.Embedding(vocab_size, width)\n        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))\n        self.transformer = Transformer(\n            width=width,\n            layers=layers,\n            heads=heads,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n        self.ln_final = norm_layer(width)\n\n        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\n\n        self.init_parameters()\n\n    def init_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n        if self.cls_emb is not None:\n            nn.init.normal_(self.cls_emb, std=0.01)\n\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        attn_std = self.transformer.width ** -0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.grad_checkpointing = enable\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.num_pos, self.num_pos)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def build_cls_mask(self, text, cast_dtype: torch.dtype):\n        cls_mask = (text != self.pad_id).unsqueeze(1)\n        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)\n        additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)\n        additive_mask.fill_(0)\n        additive_mask.masked_fill_(~cls_mask, float(\"-inf\"))\n        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)\n        return additive_mask\n\n    def _repeat(self, t, N: int):\n        return t.reshape(1, 1, -1).repeat(N, 1, 1)\n\n    def forward(self, text):\n        cast_dtype = self.transformer.get_cast_dtype()\n        seq_len = text.shape[1]\n\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\n        attn_mask = self.attn_mask\n        if self.cls_emb is not None:\n            seq_len += 1\n            x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)\n            cls_mask = self.build_cls_mask(text, cast_dtype)\n            attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]\n\n        x = x + self.positional_embedding[:seq_len].to(cast_dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        if self.cls_emb is not None:\n            pooled, tokens = x[:, -1], x[:, :-1]\n            pooled = self.ln_final(pooled)\n        else:\n            x = self.ln_final(x)\n            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x\n\n        if self.text_projection is not None:\n            pooled = pooled @ self.text_projection\n\n        if self.output_tokens:\n            return pooled, tokens\n\n        return pooled\n\n\nclass MultimodalTransformer(Transformer):\n    def __init__(\n            self,\n            width: int,\n            layers: int,\n            heads: int,\n            context_length: int = 77,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            output_dim: int = 512,\n    ):\n\n        super().__init__(\n            width=width,\n            layers=layers,\n            heads=heads,\n            mlp_ratio=mlp_ratio,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n        self.context_length = context_length\n        self.cross_attn = nn.ModuleList([\n            ResidualAttentionBlock(\n                width,\n                heads,\n                mlp_ratio,\n                ls_init_value=ls_init_value,\n                act_layer=act_layer,\n                norm_layer=norm_layer,\n                is_cross_attention=True,\n            )\n            for _ in range(layers)\n        ])\n\n        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\n\n        self.ln_final = norm_layer(width)\n        self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n\n    def init_parameters(self):\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        attn_std = self.transformer.width ** -0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n        for block in self.transformer.cross_attn:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def forward(self, image_embs, text_embs):\n        text_embs = text_embs.permute(1, 0, 2)  # NLD -> LNDsq\n        image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND\n        seq_len = text_embs.shape[0]\n\n        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n                text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])\n                text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)\n            else:\n                text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])\n                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)\n\n        x = text_embs.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x)\n\n        if self.text_projection is not None:\n            x = x @ self.text_projection\n\n        return x\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n"
  },
  {
    "path": "open_clip/utils.py",
    "content": "from itertools import repeat\nimport collections.abc\n\nfrom torch import nn as nn\nfrom torchvision.ops.misc import FrozenBatchNorm2d\n\n\ndef freeze_batch_norm_2d(module, module_match={}, name=''):\n    \"\"\"\n    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\n    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\n    returned. Otherwise, the module is walked recursively and submodules are converted in place.\n\n    Args:\n        module (torch.nn.Module): Any PyTorch module.\n        module_match (dict): Dictionary of full module names to freeze (all if empty)\n        name (str): Full module name (prefix)\n\n    Returns:\n        torch.nn.Module: Resulting module\n\n    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\n    \"\"\"\n    res = module\n    is_match = True\n    if module_match:\n        is_match = name in module_match\n    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):\n        res = FrozenBatchNorm2d(module.num_features)\n        res.num_features = module.num_features\n        res.affine = module.affine\n        if module.affine:\n            res.weight.data = module.weight.data.clone().detach()\n            res.bias.data = module.bias.data.clone().detach()\n        res.running_mean.data = module.running_mean.data\n        res.running_var.data = module.running_var.data\n        res.eps = module.eps\n    else:\n        for child_name, child in module.named_children():\n            full_child_name = '.'.join([name, child_name]) if name else child_name\n            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\n            if new_child is not child:\n                res.add_module(child_name, new_child)\n    return res\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable):\n            return x\n        return tuple(repeat(x, n))\n    return parse\n\n\nto_1tuple = _ntuple(1)\nto_2tuple = _ntuple(2)\nto_3tuple = _ntuple(3)\nto_4tuple = _ntuple(4)\nto_ntuple = lambda n, x: _ntuple(n)(x)\n"
  },
  {
    "path": "open_clip/version.py",
    "content": "__version__ = '2.16.0'\n"
  },
  {
    "path": "prompt_ensemble.py",
    "content": "import os\r\nfrom typing import Union, List\r\nfrom pkg_resources import packaging\r\nimport torch\r\nimport numpy as np\r\n\r\n\r\ndef encode_text_with_prompt_ensemble(model, objs, tokenizer, device):\r\n    prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']\r\n    prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']\r\n    prompt_state = [prompt_normal, prompt_abnormal]\r\n    prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']\r\n\r\n    text_prompts = {}\r\n    for obj in objs:\r\n        text_features = []\r\n        for i in range(len(prompt_state)):\r\n            prompted_state = [state.format(obj) for state in prompt_state[i]]\r\n            prompted_sentence = []\r\n            for s in prompted_state:\r\n                for template in prompt_templates:\r\n                    prompted_sentence.append(template.format(s))\r\n            prompted_sentence = tokenizer(prompted_sentence).to(device)\r\n            class_embeddings = model.encode_text(prompted_sentence)\r\n            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\r\n            class_embedding = class_embeddings.mean(dim=0)\r\n            class_embedding /= class_embedding.norm()\r\n            text_features.append(class_embedding)\r\n\r\n        text_features = torch.stack(text_features, dim=1).to(device)\r\n        text_prompts[obj] = text_features\r\n\r\n    return text_prompts"
  },
  {
    "path": "requirements.txt",
    "content": "ftfy==6.1.1\nhorovod==0.28.1\nhuggingface_hub==0.13.4\nnumpy==1.21.6\nopencv_python==4.6.0.66\npandas==1.3.5\nPillow==9.2.0\nregex==2022.10.31\nscikit_image==0.19.3\nscikit_learn==1.0.2\nsetuptools==63.4.1\ntabulate==0.9.0\ntimm==0.8.15.dev0\ntorch==1.12.1\ntorchvision==0.13.1\ntqdm==4.64.1\ntransformers==4.15.0\n"
  },
  {
    "path": "test.py",
    "content": "import os\r\nimport cv2\r\nimport json\r\nimport torch\r\nimport random\r\nimport logging\r\nimport argparse\r\nimport numpy as np\r\nfrom PIL import Image\r\nfrom skimage import measure\r\nfrom tabulate import tabulate\r\nimport torch.nn.functional as F\r\nimport torchvision.transforms as transforms\r\nfrom sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise\r\n\r\nimport open_clip\r\nfrom few_shot import memory\r\nfrom model import LinearLayer\r\nfrom dataset import VisaDataset, MVTecDataset\r\nfrom prompt_ensemble import encode_text_with_prompt_ensemble\r\n\r\n\r\ndef setup_seed(seed):\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed_all(seed)\r\n    np.random.seed(seed)\r\n    random.seed(seed)\r\n    torch.backends.cudnn.deterministic = True\r\n    torch.backends.cudnn.benchmark = False\r\n\r\n\r\ndef normalize(pred, max_value=None, min_value=None):\r\n    if max_value is None or min_value is None:\r\n        return (pred - pred.min()) / (pred.max() - pred.min())\r\n    else:\r\n        return (pred - min_value) / (max_value - min_value)\r\n\r\n\r\ndef apply_ad_scoremap(image, scoremap, alpha=0.5):\r\n    np_image = np.asarray(image, dtype=float)\r\n    scoremap = (scoremap * 255).astype(np.uint8)\r\n    scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)\r\n    scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)\r\n    return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)\r\n\r\n\r\ndef cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):\r\n    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py\r\n    binary_amaps = np.zeros_like(amaps, dtype=bool)\r\n    min_th, max_th = amaps.min(), amaps.max()\r\n    delta = (max_th - min_th) / max_step\r\n    pros, fprs, ths = [], [], []\r\n    for th in np.arange(min_th, max_th, delta):\r\n        binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1\r\n        pro = []\r\n        for binary_amap, mask in zip(binary_amaps, masks):\r\n            for region in measure.regionprops(measure.label(mask)):\r\n                tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()\r\n                pro.append(tp_pixels / region.area)\r\n        inverse_masks = 1 - masks\r\n        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()\r\n        fpr = fp_pixels / inverse_masks.sum()\r\n        pros.append(np.array(pro).mean())\r\n        fprs.append(fpr)\r\n        ths.append(th)\r\n    pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)\r\n    idxes = fprs < expect_fpr\r\n    fprs = fprs[idxes]\r\n    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())\r\n    pro_auc = auc(fprs, pros[idxes])\r\n    return pro_auc\r\n\r\n\r\ndef test(args):\r\n    img_size = args.image_size\r\n    features_list = args.features_list\r\n    few_shot_features = args.few_shot_features\r\n    dataset_dir = args.data_path\r\n    save_path = args.save_path\r\n    dataset_name = args.dataset\r\n    if not os.path.exists(save_path):\r\n        os.makedirs(save_path)\r\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\r\n    txt_path = os.path.join(save_path, 'log.txt')\r\n\r\n    # clip\r\n    model, _, preprocess = open_clip.create_model_and_transforms(args.model, img_size, pretrained=args.pretrained)\r\n    model.to(device)\r\n    tokenizer = open_clip.get_tokenizer(args.model)\r\n\r\n    # logger\r\n    root_logger = logging.getLogger()\r\n    for handler in root_logger.handlers[:]:\r\n        root_logger.removeHandler(handler)\r\n    root_logger.setLevel(logging.WARNING)\r\n    logger = logging.getLogger('test')\r\n    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',\r\n                                  datefmt='%y-%m-%d %H:%M:%S')\r\n    logger.setLevel(logging.INFO)\r\n    file_handler = logging.FileHandler(txt_path, mode='w')\r\n    file_handler.setFormatter(formatter)\r\n    logger.addHandler(file_handler)\r\n    console_handler = logging.StreamHandler()\r\n    console_handler.setFormatter(formatter)\r\n    logger.addHandler(console_handler)\r\n\r\n    # record parameters\r\n    for arg in vars(args):\r\n        if args.mode == 'zero_shot' and (arg == 'k_shot' or arg == 'few_shot_features'):\r\n            continue\r\n        logger.info(f'{arg}: {getattr(args, arg)}')\r\n\r\n    # seg\r\n    with open(args.config_path, 'r') as f:\r\n        model_configs = json.load(f)\r\n    linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],\r\n                              len(features_list), args.model).to(device)\r\n    checkpoint = torch.load(args.checkpoint_path)\r\n    linearlayer.load_state_dict(checkpoint[\"trainable_linearlayer\"])\r\n\r\n    # dataset\r\n    transform = transforms.Compose([\r\n            transforms.Resize((img_size, img_size)),\r\n            transforms.CenterCrop(img_size),\r\n            transforms.ToTensor()\r\n        ])\r\n    if dataset_name == 'mvtec':\r\n        test_data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,\r\n                                 aug_rate=-1, mode='test')\r\n    else:\r\n        test_data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, mode='test')\r\n    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)\r\n    obj_list = test_data.get_cls_names()\r\n\r\n    # few shot\r\n    if args.mode == 'few_shot':\r\n        mem_features = memory(args.model, model, obj_list, dataset_dir, save_path, preprocess, transform,\r\n                              args.k_shot, few_shot_features, dataset_name, device)\r\n\r\n    # text prompt\r\n    with torch.cuda.amp.autocast(), torch.no_grad():\r\n        text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)\r\n\r\n    results = {}\r\n    results['cls_names'] = []\r\n    results['imgs_masks'] = []\r\n    results['anomaly_maps'] = []\r\n    results['gt_sp'] = []\r\n    results['pr_sp'] = []\r\n    for items in test_dataloader:\r\n        image = items['img'].to(device)\r\n        cls_name = items['cls_name']\r\n        results['cls_names'].append(cls_name[0])\r\n        gt_mask = items['img_mask']\r\n        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0\r\n        results['imgs_masks'].append(gt_mask)  # px\r\n        results['gt_sp'].append(items['anomaly'].item())\r\n\r\n        with torch.no_grad(), torch.cuda.amp.autocast():\r\n            image_features, patch_tokens = model.encode_image(image, features_list)\r\n            image_features /= image_features.norm(dim=-1, keepdim=True)\r\n            text_features = []\r\n            for cls in cls_name:\r\n                text_features.append(text_prompts[cls])\r\n            text_features = torch.stack(text_features, dim=0)\r\n\r\n            # sample\r\n            text_probs = (100.0 * image_features @ text_features[0]).softmax(dim=-1)\r\n            results['pr_sp'].append(text_probs[0][1].cpu().item())\r\n\r\n            # pixel\r\n            patch_tokens = linearlayer(patch_tokens)\r\n            anomaly_maps = []\r\n            for layer in range(len(patch_tokens)):\r\n                patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)\r\n                anomaly_map = (100.0 * patch_tokens[layer] @ text_features)\r\n                B, L, C = anomaly_map.shape\r\n                H = int(np.sqrt(L))\r\n                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),\r\n                                            size=img_size, mode='bilinear', align_corners=True)\r\n                anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]\r\n                anomaly_maps.append(anomaly_map.cpu().numpy())\r\n            anomaly_map = np.sum(anomaly_maps, axis=0)\r\n\r\n            # few shot\r\n            if args.mode == 'few_shot':\r\n                image_features, patch_tokens = model.encode_image(image, few_shot_features)\r\n                anomaly_maps_few_shot = []\r\n                for idx, p in enumerate(patch_tokens):\r\n                    if 'ViT' in args.model:\r\n                        p = p[0, 1:, :]\r\n                    else:\r\n                        p = p[0].view(p.shape[1], -1).permute(1, 0).contiguous()\r\n                    cos = pairwise.cosine_similarity(mem_features[cls_name[0]][idx].cpu(), p.cpu())\r\n                    height = int(np.sqrt(cos.shape[1]))\r\n                    anomaly_map_few_shot = np.min((1 - cos), 0).reshape(1, 1, height, height)\r\n                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),\r\n                                                         size=img_size, mode='bilinear', align_corners=True)\r\n                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())\r\n                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)\r\n                anomaly_map = anomaly_map + anomaly_map_few_shot\r\n            \r\n            results['anomaly_maps'].append(anomaly_map)\r\n\r\n            # visualization\r\n            path = items['img_path']\r\n            cls = path[0].split('/')[-2]\r\n            filename = path[0].split('/')[-1]\r\n            vis = cv2.cvtColor(cv2.resize(cv2.imread(path[0]), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB\r\n            mask = normalize(anomaly_map[0])\r\n            vis = apply_ad_scoremap(vis, mask)\r\n            vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR\r\n            save_vis = os.path.join(save_path, 'imgs', cls_name[0], cls)\r\n            if not os.path.exists(save_vis):\r\n                os.makedirs(save_vis)\r\n            cv2.imwrite(os.path.join(save_vis, filename), vis)\r\n\r\n    # metrics\r\n    table_ls = []\r\n    auroc_sp_ls = []\r\n    auroc_px_ls = []\r\n    f1_sp_ls = []\r\n    f1_px_ls = []\r\n    aupro_ls = []\r\n    ap_sp_ls = []\r\n    ap_px_ls = []\r\n    for obj in obj_list:\r\n        table = []\r\n        gt_px = []\r\n        pr_px = []\r\n        gt_sp = []\r\n        pr_sp = []\r\n        pr_sp_tmp = []\r\n        table.append(obj)\r\n        for idxes in range(len(results['cls_names'])):\r\n            if results['cls_names'][idxes] == obj:\r\n                gt_px.append(results['imgs_masks'][idxes].squeeze(1).numpy())\r\n                pr_px.append(results['anomaly_maps'][idxes])\r\n                pr_sp_tmp.append(np.max(results['anomaly_maps'][idxes]))\r\n                gt_sp.append(results['gt_sp'][idxes])\r\n                pr_sp.append(results['pr_sp'][idxes])\r\n        gt_px = np.array(gt_px)\r\n        gt_sp = np.array(gt_sp)\r\n        pr_px = np.array(pr_px)\r\n        pr_sp = np.array(pr_sp)\r\n        if args.mode == 'few_shot':\r\n            pr_sp_tmp = np.array(pr_sp_tmp)\r\n            pr_sp_tmp = (pr_sp_tmp - pr_sp_tmp.min()) / (pr_sp_tmp.max() - pr_sp_tmp.min())\r\n            pr_sp = 0.5 * (pr_sp + pr_sp_tmp)\r\n\r\n        auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())\r\n        auroc_sp = roc_auc_score(gt_sp, pr_sp)\r\n        ap_sp = average_precision_score(gt_sp, pr_sp)\r\n        ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())\r\n        # f1_sp\r\n        precisions, recalls, thresholds = precision_recall_curve(gt_sp, pr_sp)\r\n        f1_scores = (2 * precisions * recalls) / (precisions + recalls)\r\n        f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])\r\n        # f1_px\r\n        precisions, recalls, thresholds = precision_recall_curve(gt_px.ravel(), pr_px.ravel())\r\n        f1_scores = (2 * precisions * recalls) / (precisions + recalls)\r\n        f1_px = np.max(f1_scores[np.isfinite(f1_scores)])\r\n        # aupro\r\n        if len(gt_px.shape) == 4:\r\n            gt_px = gt_px.squeeze(1)\r\n        if len(pr_px.shape) == 4:\r\n            pr_px = pr_px.squeeze(1)\r\n        aupro = cal_pro_score(gt_px, pr_px)\r\n\r\n        table.append(str(np.round(auroc_px * 100, decimals=1)))\r\n        table.append(str(np.round(f1_px * 100, decimals=1)))\r\n        table.append(str(np.round(ap_px * 100, decimals=1)))\r\n        table.append(str(np.round(aupro * 100, decimals=1)))\r\n        table.append(str(np.round(auroc_sp * 100, decimals=1)))\r\n        table.append(str(np.round(f1_sp * 100, decimals=1)))\r\n        table.append(str(np.round(ap_sp * 100, decimals=1)))\r\n\r\n        table_ls.append(table)\r\n        auroc_sp_ls.append(auroc_sp)\r\n        auroc_px_ls.append(auroc_px)\r\n        f1_sp_ls.append(f1_sp)\r\n        f1_px_ls.append(f1_px)\r\n        aupro_ls.append(aupro)\r\n        ap_sp_ls.append(ap_sp)\r\n        ap_px_ls.append(ap_px)\r\n\r\n    # logger\r\n    table_ls.append(['mean', str(np.round(np.mean(auroc_px_ls) * 100, decimals=1)),\r\n                     str(np.round(np.mean(f1_px_ls) * 100, decimals=1)), str(np.round(np.mean(ap_px_ls) * 100, decimals=1)),\r\n                     str(np.round(np.mean(aupro_ls) * 100, decimals=1)), str(np.round(np.mean(auroc_sp_ls) * 100, decimals=1)),\r\n                     str(np.round(np.mean(f1_sp_ls) * 100, decimals=1)), str(np.round(np.mean(ap_sp_ls) * 100, decimals=1))])\r\n    results = tabulate(table_ls, headers=['objects', 'auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp',\r\n                                          'f1_sp', 'ap_sp'], tablefmt=\"pipe\")\r\n    logger.info(\"\\n%s\", results)\r\n\r\n\r\nif __name__ == '__main__':\r\n    parser = argparse.ArgumentParser(\"VAND Challenge\", add_help=True)\r\n    # paths\r\n    parser.add_argument(\"--data_path\", type=str, default=\"./data/visa\", help=\"path to test dataset\")\r\n    parser.add_argument(\"--save_path\", type=str, default='./results/tiaoshi', help='path to save results')\r\n    parser.add_argument(\"--checkpoint_path\", type=str, default='./exps/vit_huge_14/model_epoch12.pth', help='path to save results')\r\n    parser.add_argument(\"--config_path\", type=str, default='./open_clip/model_configs/ViT-B-16.json', help=\"model configs\")\r\n    # model\r\n    parser.add_argument(\"--dataset\", type=str, default='mvtec', help=\"test dataset\")\r\n    parser.add_argument(\"--model\", type=str, default=\"ViT-B-16\", help=\"model used\")\r\n    parser.add_argument(\"--pretrained\", type=str, default=\"laion400m_e32\", help=\"pretrained weight used\")\r\n    parser.add_argument(\"--features_list\", type=int, nargs=\"+\", default=[3, 6, 9], help=\"features used\")\r\n    parser.add_argument(\"--few_shot_features\", type=int, nargs=\"+\", default=[3, 6, 9], help=\"features used for few shot\")\r\n    parser.add_argument(\"--image_size\", type=int, default=224, help=\"image size\")\r\n    parser.add_argument(\"--mode\", type=str, default=\"zero_shot\", help=\"zero shot or few shot\")\r\n    # few shot\r\n    parser.add_argument(\"--k_shot\", type=int, default=10, help=\"e.g., 10-shot, 5-shot, 1-shot\")\r\n    parser.add_argument(\"--seed\", type=int, default=10, help=\"random seed\")\r\n    args = parser.parse_args()\r\n\r\n    setup_seed(args.seed)\r\n    test(args)\r\n"
  },
  {
    "path": "test_few_shot.sh",
    "content": "### test on the VisA dataset\npython test.py --mode few_shot --dataset visa \\\n--data_path ./data/visa --save_path ./results/visa/few_shot/4shot/seed42 \\\n--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \\\n--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \\\n--pretrained openai --image_size 518 --k_shot 4 --seed 42\n\n\n### test on the MVTec AD dataset\npython test.py --mode few_shot --dataset mvtec \\\n--data_path ./data/mvtec --save_path ./results/mvtec/few_shot/4shot/seed42 \\\n--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \\\n--model ViT-L-14-336 --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \\\n--pretrained openai --image_size 518 --k_shot 4 --seed 42\n"
  },
  {
    "path": "test_zero_shot.sh",
    "content": "### test on the VisA dataset\npython test.py --mode zero_shot --dataset visa \\\n--data_path ./data/visa --save_path ./results/visa/zero_shot \\\n--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/mvtec_pretrained.pth \\\n--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518\n\n### test on the MVTec AD dataset\npython test.py --mode zero_shot --dataset mvtec \\\n--data_path ./data/mvtec --save_path ./results/mvtec/zero_shot \\\n--config_path ./open_clip/model_configs/ViT-L-14-336.json --checkpoint_path ./exps/pretrained/visa_pretrained.pth \\\n--model ViT-L-14-336 --features_list 6 12 18 24 --pretrained openai --image_size 518\n\n\n"
  },
  {
    "path": "train.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport numpy as np\r\nimport random\r\nimport os\r\nimport json\r\nimport argparse\r\nfrom torch.utils.data import DataLoader\r\nfrom datetime import datetime\r\nfrom torch.nn import functional as F\r\nimport torch.backends.cudnn as cudnn\r\nimport torchvision.transforms as transforms\r\nimport logging\r\n\r\nimport open_clip\r\nfrom dataset import VisaDataset, MVTecDataset\r\nfrom model import LinearLayer\r\nfrom loss import FocalLoss, BinaryDiceLoss\r\nfrom prompt_ensemble import encode_text_with_prompt_ensemble\r\n\r\n\r\ndef setup_seed(seed):\r\n    torch.manual_seed(seed)\r\n    torch.cuda.manual_seed_all(seed)\r\n    np.random.seed(seed)\r\n    random.seed(seed)\r\n    torch.backends.cudnn.deterministic = True\r\n    torch.backends.cudnn.benchmark = False\r\n\r\n\r\ndef train(args):\r\n    # configs\r\n    epochs = args.epoch\r\n    learning_rate = args.learning_rate\r\n    batch_size = args.batch_size\r\n    image_size = args.image_size\r\n    device = 'cuda' if torch.cuda.is_available() else 'cpu'\r\n    save_path = args.save_path\r\n    if not os.path.exists(save_path):\r\n        os.makedirs(save_path)\r\n    txt_path = os.path.join(save_path, 'log.txt')  # log\r\n\r\n    # model configs\r\n    features_list = args.features_list\r\n    with open(args.config_path, 'r') as f:\r\n        model_configs = json.load(f)\r\n\r\n    # clip model\r\n    model, _, preprocess = open_clip.create_model_and_transforms(args.model, image_size, pretrained=args.pretrained)\r\n    model.to(device)\r\n    tokenizer = open_clip.get_tokenizer(args.model)\r\n\r\n    # logger\r\n    root_logger = logging.getLogger()\r\n    for handler in root_logger.handlers[:]:\r\n        root_logger.removeHandler(handler)\r\n    root_logger.setLevel(logging.WARNING)\r\n    logger = logging.getLogger('train')\r\n    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',\r\n                                  datefmt='%y-%m-%d %H:%M:%S')\r\n    logger.setLevel(logging.INFO)\r\n    file_handler = logging.FileHandler(txt_path, mode='w')\r\n    file_handler.setFormatter(formatter)\r\n    logger.addHandler(file_handler)\r\n    console_handler = logging.StreamHandler()\r\n    console_handler.setFormatter(formatter)\r\n    logger.addHandler(console_handler)\r\n\r\n    # record parameters\r\n    for arg in vars(args):\r\n        logger.info(f'{arg}: {getattr(args, arg)}')\r\n\r\n    # transforms\r\n    transform = transforms.Compose([\r\n        transforms.Resize((image_size, image_size)),\r\n        transforms.CenterCrop(image_size),\r\n        transforms.ToTensor()\r\n    ])\r\n    \r\n    # datasets\r\n    if args.dataset == 'mvtec':\r\n        train_data = MVTecDataset(root=args.train_data_path, transform=preprocess, target_transform=transform,\r\n                                  aug_rate=args.aug_rate)\r\n    else:\r\n        train_data = VisaDataset(root=args.train_data_path, transform=preprocess, target_transform=transform)\r\n    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\r\n\r\n    # linear layer\r\n    trainable_layer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],\r\n                                  len(args.features_list), args.model).to(device)\r\n\r\n    optimizer = torch.optim.Adam(list(trainable_layer.parameters()), lr=learning_rate, betas=(0.5, 0.999))\r\n\r\n    # losses\r\n    loss_focal = FocalLoss()\r\n    loss_dice = BinaryDiceLoss()\r\n\r\n    # text prompt\r\n    with torch.cuda.amp.autocast(), torch.no_grad():\r\n        obj_list = train_data.get_cls_names()\r\n        text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)\r\n\r\n    for epoch in range(epochs):\r\n        loss_list = []\r\n        idx = 0\r\n        for items in train_dataloader:\r\n            idx += 1\r\n            image = items['img'].to(device)\r\n            cls_name = items['cls_name']\r\n            with torch.cuda.amp.autocast():\r\n                with torch.no_grad():\r\n                    image_features, patch_tokens = model.encode_image(image, features_list)\r\n                    text_features = []\r\n                    for cls in cls_name:\r\n                        text_features.append(text_prompts[cls])\r\n                    text_features = torch.stack(text_features, dim=0)\r\n\r\n                # pixel level\r\n                patch_tokens = trainable_layer(patch_tokens)\r\n                anomaly_maps = []\r\n                for layer in range(len(patch_tokens)):\r\n                    patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)\r\n                    anomaly_map = (100.0 * patch_tokens[layer] @ text_features)\r\n                    B, L, C = anomaly_map.shape\r\n                    H = int(np.sqrt(L))\r\n                    anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),\r\n                                                size=image_size, mode='bilinear', align_corners=True)\r\n                    anomaly_map = torch.softmax(anomaly_map, dim=1)\r\n                    anomaly_maps.append(anomaly_map)\r\n\r\n            # losses\r\n            gt = items['img_mask'].squeeze().to(device)\r\n            gt[gt > 0.5], gt[gt <= 0.5] = 1, 0\r\n            loss = 0\r\n            for num in range(len(anomaly_maps)):\r\n                loss += loss_focal(anomaly_maps[num], gt)\r\n                loss += loss_dice(anomaly_maps[num][:, 1, :, :], gt)\r\n\r\n            optimizer.zero_grad()\r\n            loss.backward()\r\n            optimizer.step()\r\n            loss_list.append(loss.item())\r\n\r\n        # logs\r\n        if (epoch + 1) % args.print_freq == 0:\r\n            logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, np.mean(loss_list)))\r\n\r\n        # save model\r\n        if (epoch + 1) % args.save_freq == 0:\r\n            ckp_path = os.path.join(save_path, 'epoch_' + str(epoch + 1) + '.pth')\r\n            torch.save({'trainable_linearlayer': trainable_layer.state_dict()}, ckp_path)\r\n\r\n\r\nif __name__ == '__main__':\r\n    parser = argparse.ArgumentParser(\"VAND Challenge\", add_help=True)\r\n    # path\r\n    parser.add_argument(\"--train_data_path\", type=str, default=\"./data/visa\", help=\"train dataset path\")\r\n    parser.add_argument(\"--save_path\", type=str, default='./exps/vit_large_14_518', help='path to save results')\r\n    parser.add_argument(\"--config_path\", type=str, default='./open_clip/model_configs/ViT-B-16.json', help=\"model configs\")\r\n    # model\r\n    parser.add_argument(\"--dataset\", type=str, default='mvtec', help=\"train dataset name\")\r\n    parser.add_argument(\"--model\", type=str, default=\"ViT-B-16\", help=\"model used\")\r\n    parser.add_argument(\"--pretrained\", type=str, default=\"laion400m_e32\", help=\"pretrained weight used\")\r\n    parser.add_argument(\"--features_list\", type=int, nargs=\"+\", default=[3, 6, 9], help=\"features used\")\r\n    # hyper-parameter\r\n    parser.add_argument(\"--epoch\", type=int, default=200, help=\"epochs\")\r\n    parser.add_argument(\"--learning_rate\", type=float, default=0.001, help=\"learning rate\")\r\n    parser.add_argument(\"--batch_size\", type=int, default=16, help=\"batch size\")\r\n    parser.add_argument(\"--image_size\", type=int, default=224, help=\"image size\")\r\n    parser.add_argument(\"--aug_rate\", type=float, default=0.2, help=\"image size\")\r\n    parser.add_argument(\"--print_freq\", type=int, default=30, help=\"print frequency\")\r\n    parser.add_argument(\"--save_freq\", type=int, default=3, help=\"save frequency\")\r\n    args = parser.parse_args()\r\n\r\n    setup_seed(111)\r\n    train(args)\r\n\r\n"
  },
  {
    "path": "train.sh",
    "content": "### train on the MVTec AD dataset\npython train.py --dataset mvtec --train_data_path ./data/mvtec \\\n--save_path ./exps/visa/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \\\n--features_list 6 12 18 24 --pretrained openai --image_size 518  --batch_size 8 --aug_rate 0.2 --print_freq 1 \\\n--epoch 3 --save_freq 1\n\n\n### train on the VisA dataset\npython train.py --dataset visa --train_data_path ./data/visa \\\n--save_path ./exps/mvtec/vit_large_14_518 --config_path ./open_clip/model_configs/ViT-L-14-336.json --model ViT-L-14-336 \\\n--features_list 6 12 18 24 --pretrained openai --image_size 518  --batch_size 8 --print_freq 1 \\\n--epoch 15 --save_freq 1\n\n"
  }
]