Repository: wangf3014/SCLIP
Branch: main
Commit: b56bd3014b9e
Files: 30
Total size: 100.9 KB

Directory structure:
gitextract_zfhzwwwp/
├── README.md
├── clip/
│   ├── __init__.py
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── clip_segmentor.py
├── configs/
│   ├── base_config.py
│   ├── cfg_ade20k.py
│   ├── cfg_city_scapes.py
│   ├── cfg_coco_object.py
│   ├── cfg_coco_stuff10k.py
│   ├── cfg_coco_stuff164k.py
│   ├── cfg_context59.py
│   ├── cfg_context60.py
│   ├── cfg_voc20.py
│   ├── cfg_voc21.py
│   ├── cls_ade20k.txt
│   ├── cls_city_scapes.txt
│   ├── cls_coco_object.txt
│   ├── cls_coco_stuff.txt
│   ├── cls_context59.txt
│   ├── cls_context60.txt
│   ├── cls_voc20.txt
│   └── cls_voc21.txt
├── custom_datasets.py
├── datasets/
│   └── cvt_coco_object.py
├── dist_test.sh
├── eval.py
├── pamr.py
└── prompts/
    └── imagenet_template.py


================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference

**News: this paper has been accepted by ECCV 2024**

**Official PyTorch implementation of SCLIP**

* [SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference](https://arxiv.org/pdf/2312.01597.pdf).
* A **simple** but very effective open-vocabulary semantic segmentation model derived from CLIP.
* **SOTA** zero-shot segmentation results obtained by minimal modifications to CLIP's self-attention.

**Model components and our Correlative Self-Attention maps:**

![sclip_0](figs/sclip_0.png)

**Open-vocabulary semantic segmentation samples:**

![sclip_1](figs/sclip_1.png)

## Dependencies
This repo is built on top of [CLIP](https://github.com/openai/CLIP) and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). To run SCLIP, please install the following packages in your PyTorch environment.
We recommend using PyTorch==1.10.x for better compatibility with the following MMSeg version.

```
pip install openmim
mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1
pip install ftfy regex yapf==0.40.1
```

## Datasets
We include the following dataset configurations in this repo: PASCAL VOC, PASCAL Context, Cityscapes, ADE20k, COCO-Stuff10k, and COCO-Stuff164k, plus three variant datasets: VOC20 and Context59 (i.e., PASCAL VOC and PASCAL Context without the background category), and COCO-Object.

Please follow the [MMSeg data preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download and pre-process the datasets. The COCO-Object dataset can be converted from COCO-Stuff164k by executing the following command:

```
python datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO164K
```

**Remember to modify the dataset paths in the config files in** `configs/cfg_DATASET.py`

## Run SCLIP
Single-GPU running:

```
python eval.py --config ./configs/cfg_DATASET.py --workdir YOUR_WORK_DIR
```

Multi-GPU running:

```
bash ./dist_test.sh ./configs/cfg_DATASET.py
```

## Results
The performance of open-vocabulary inference can be affected by the text targets, i.e., the prompts and class names. This repo provides an easy way to explore them: you can modify the prompts in `prompts/imagenet_template.py` and the class names in `configs/cls_DATASET.txt`.

The repo automatically loads class names from the `configs/cls_DATASET.txt` file. Each category occupies one line of the file and may have several class names on that line, separated by commas.

With the default setup in this repo, you should get the following results:

| Dataset               | mIoU  |
| --------------------- | ----- |
| ADE20k                | 16.45 |
| Cityscapes            | 32.34 |
| COCO-Object           | 33.52 |
| COCO-Stuff10k         | 25.91 |
| COCO-Stuff164k        | 22.77 |
| PASCAL Context59      | 34.46 |
| PASCAL Context60      | 31.74 |
| PASCAL VOC (w/o. bg.) | 81.54 |
| PASCAL VOC (w. bg.)   | 59.63 |

## Citation
```
@article{wang2023sclip,
  title={SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference},
  author={Wang, Feng and Mei, Jieru and Yuille, Alan},
  journal={arXiv preprint arXiv:2312.01597},
  year={2023}
}
```

================================================
FILE: clip/__init__.py
================================================
from .clip import *
from .model import *

================================================
FILE: clip/clip.py
================================================
### CLIP source code from OpenAI:
# https://github.com/openai/CLIP/blob/main/clip/clip.py

import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the
    names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)
    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

================================================
FILE: clip/model.py
================================================
### CLIP source code from OpenAI:
# https://github.com/openai/CLIP/blob/main/clip/model.py

from collections import OrderedDict
from typing import Tuple, Union

import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import torchvision.transforms.functional as VF


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x, return_all_tokens=False):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        if return_all_tokens:
            return x
        else:
            return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
                                        heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, return_all_tokens=False):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x, return_all_tokens)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads,
                                                                attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.output_dim = output_dim
        self.width = width
        self.heads = heads
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, return_all=False, csa=True):
        B, nc, w, h = x.shape

        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        if x.shape[1] != self.positional_embedding.shape[0]:
            x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
        else:
            x = x + self.positional_embedding.to(x.dtype)

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        for blk in self.transformer.resblocks[:-1]:
            x = blk(x)
        for blk in self.transformer.resblocks[-1:]:
            x = x + self.custom_attn(blk.attn, blk.ln_1(x), csa=csa)
            x = x + blk.mlp(blk.ln_2(x))
        x = x.permute(1, 0, 2)  # LND -> NLD

        if return_all:
            return self.ln_post(x) @ self.proj

        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj

        return x

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.positional_embedding.shape[0] - 1
        if npatch == N and w == h:
            return self.positional_embedding
        class_pos_embed = self.positional_embedding[[0]]
        patch_pos_embed = self.positional_embedding[1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def custom_attn(self, attn_layer, x, return_attn=False, with_attn=False, csa=False):
        num_heads = attn_layer.num_heads
        _, bsz, embed_dim = x.size()
        head_dim = embed_dim // num_heads
        scale = head_dim ** -0.5

        q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1)
        q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

        if csa:
            q_attn = torch.bmm(q, q.transpose(1, 2)) * scale
            k_attn = torch.bmm(k, k.transpose(1, 2)) * scale
            attn_weights = F.softmax(q_attn, dim=-1) + F.softmax(k_attn, dim=-1)
        else:
            attn_weights = torch.bmm(q * scale, k.transpose(1, 2))
            attn_weights = F.softmax(attn_weights, dim=-1)

        if return_attn:
            return attn_weights

        attn_output = torch.bmm(attn_weights, v)
        attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim)
        attn_output = attn_layer.out_proj(attn_output)

        if with_attn:
            return attn_output, attn_weights

        return attn_output

    def get_attn(self, x, layer='all', csa=False):
        B, nc, w, h = x.shape

        x = self.conv1(x.type(self.conv1.weight.dtype))  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        if x.shape[1] != self.positional_embedding.shape[0]:
            x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
        else:
            x = x + self.positional_embedding.to(x.dtype)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        if layer == 'final':
            for blk in self.transformer.resblocks[:-1]:
                x = blk(x)
            attn_map = self.custom_attn(self.transformer.resblocks[-1].attn,
                                        self.transformer.resblocks[-1].ln_1(x),
                                        csa=csa, return_attn=True)
            return attn_map
        elif layer == 'all':
            attn_map = []
            for blk in self.transformer.resblocks[:-1]:
                x_i, attn_i = self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True)
                x = x + x_i
                x = x + blk.mlp(blk.ln_2(x))
                attn_map.append(attn_i)
            for blk in self.transformer.resblocks[-1:]:
                x_i, attn_i = self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True, csa=True)
                x = x + x_i
                x = x + blk.mlp(blk.ln_2(x))
                attn_map.append(attn_i)
            return attn_map
        else:
            raise ValueError('layer should be final or all')


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,  # 512
                 # vision
                 image_resolution: int,  # 224
                 vision_layers: Union[Tuple[int, int, int, int], int],  # 12
                 vision_width: int,  # 768
                 vision_patch_size: int,  # 16
                 # text
                 context_length: int,  # 77
                 vocab_size: int,  # 49408
                 transformer_width: int,  # 512
                 transformer_heads: int,  # 8
                 transformer_layers: int  # 12
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, return_all=False, csa=False):
        return self.visual(image.type(self.dtype), return_all=return_all, csa=csa)

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        return x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()

================================================
FILE: clip/simple_tokenizer.py
================================================
### CLIP source code from OpenAI:
# https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

================================================
FILE: clip_segmentor.py
================================================
import torch
import torch.nn as nn
import sys
sys.path.append("..")

import clip
from prompts.imagenet_template import openai_imagenet_template

from mmseg.models.segmentors import BaseSegmentor
from mmseg.models.data_preprocessor import SegDataPreProcessor
from mmengine.structures import PixelData

from mmseg.registry import MODELS

from pamr import PAMR


@MODELS.register_module()
class CLIPForSegmentation(BaseSegmentor):
    def __init__(self, clip_path, name_path, device=torch.device('cuda'), pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40, slide_stride=112, slide_crop=224, area_thd=None):

        data_preprocessor = SegDataPreProcessor(
            mean=[122.771, 116.746, 104.094],
            std=[68.501, 66.632, 70.323],
            rgb_to_bgr=True)
        super().__init__(data_preprocessor=data_preprocessor)
        self.net, _ = clip.load(clip_path, device=device, jit=False)

        query_words, self.query_idx = get_cls_idx(name_path)
        self.num_queries = len(query_words)
        self.num_classes = max(self.query_idx) + 1
        self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device)

        query_features = []
        with torch.no_grad():
            for qw in query_words:
                query = clip.tokenize([temp(qw) for temp in openai_imagenet_template]).to(device)
                feature = self.net.encode_text(query)
                feature /= feature.norm(dim=-1, keepdim=True)
                feature = feature.mean(dim=0)
                feature /= feature.norm()
                query_features.append(feature.unsqueeze(0))
        self.query_features = torch.cat(query_features, dim=0)

        self.dtype = self.query_features.dtype
        self.logit_scale = logit_scale
        self.prob_thd = prob_thd
        self.area_thd = area_thd
        self.slide_stride = slide_stride
        self.slide_crop = slide_crop
        self.align_corners = False

        if pamr_steps > 0:
            self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device)
        else:
            self.pamr = None

    def forward_feature(self, img, logit_size=None):
        if isinstance(img, list):
            img = img[0]

        # Token features from the CLIP image encoder; csa=True enables the
        # correlative self-attention used by SCLIP
        image_features = self.net.encode_image(img, return_all=True, csa=True)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        image_features = image_features[:, 1:]  # drop the class token, keep patch tokens
        logits = image_features @ self.query_features.T  # cosine similarity to each text query

        # Reshape the per-patch logits into a (batch, n_queries, h, w) map
        patch_size = self.net.visual.patch_size
        h, w = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size
        out_dim = logits.shape[-1]
        logits = logits.permute(0, 2, 1).reshape(-1, out_dim, h, w)

        if logit_size is None:
            logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear')
        else:
            logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear')

        return logits

    def forward_slide(self, img, img_metas, stride=112, crop_size=224):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        if isinstance(img, list):
            img = img[0].unsqueeze(0)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)

        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = img.shape
        out_channels = self.num_queries
        # Number of window positions needed to cover the image along each axis
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift windows that hit the border back so they keep the full crop size
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                crop_seg_logit = self.forward_feature(crop_img)
                # Pad the crop logits back to full-image size and accumulate them
                preds += nn.functional.pad(crop_seg_logit,
                                           (int(x1), int(preds.shape[3] - x2), int(y1),
                                            int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0

        # Average the logits where windows overlap
        preds = preds / count_mat
        img_size = img_metas[0]['ori_shape'][:2]
        logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear')

        if self.pamr is not None:
            img = nn.functional.interpolate(img, size=img_size, mode='bilinear')
            logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype)

        return logits

    def predict(self, inputs, data_samples):
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        if self.slide_crop > 0:
            seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop)
        else:
            seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape'])

        return self.postprocess_result(seg_logits, data_samples)

    def postprocess_result(self, seg_logits, data_samples):
        batch_size = seg_logits.shape[0]
        for i in range(batch_size):
            # Scale and softmax the logits of sample i (kept in a separate
            # variable so the batched tensor is not overwritten inside the loop)
            seg_logit = seg_logits[i] * self.logit_scale
            seg_logit = seg_logit.softmax(0)  # n_queries * h * w

            # If a class has several names, keep the highest probability among them
            num_cls, num_queries = max(self.query_idx) + 1, len(self.query_idx)
            if num_cls != num_queries:
                seg_logit = seg_logit.unsqueeze(0)
                cls_index = nn.functional.one_hot(self.query_idx)
                cls_index = cls_index.T.view(num_cls, num_queries, 1, 1)
                seg_logit = (seg_logit * cls_index).max(1)[0]

            seg_pred = seg_logit.argmax(0, keepdim=True)

            if self.area_thd is not None:
                # Force segmentations whose area is below self.area_thd (as a fraction
                # of the total foreground area) to 0 (background)
                predictions = nn.functional.one_hot(seg_logit.argmax(0), num_cls).to(seg_logit.dtype)
                area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True)  # exclude background (class 0)
                area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logit.dtype)
                seg_logit[1:] *= area_pred.transpose(0, -1)
                seg_pred = seg_logit.argmax(0, keepdim=True)

            # Pixels whose highest probability is below prob_thd become 0 (background)
            seg_pred[seg_logit.max(0, keepdim=True)[0] < self.prob_thd] = 0

            data_samples[i].set_data({
                'seg_logits':
                PixelData(**{'data': seg_logit}),
                'pred_sem_seg':
                PixelData(**{'data': seg_pred})
            })

        return data_samples

    def _forward(self, inputs, data_samples=None):
        """
        """

    def inference(self, img, batch_img_metas):
        """
        """

    def encode_decode(self, inputs, batch_img_metas):
        """
        """

    def extract_feat(self, inputs):
        """
        """

    def loss(self, inputs, data_samples):
        """
        """


def get_cls_idx(path):
    # One class per line; several names for the same class share a line,
    # separated by ', '
    with open(path, 'r') as f:
        name_sets = f.readlines()
    num_cls = len(name_sets)

    class_names, class_indices = [], []
    for idx in range(num_cls):
        names_i = name_sets[idx].split(', ')
        class_names += names_i
        class_indices += [idx for _ in range(len(names_i))]
    class_names = [item.replace('\n', '') for item in class_names]
    return class_names, class_indices

================================================
FILE: configs/base_config.py
================================================
# base configurations
model = dict(
    type='CLIPForSegmentation',
    clip_path='ViT-B/16'
)

test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])

default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

test_cfg = dict(type='TestLoop')

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook', interval=1))

================================================
FILE: configs/cfg_ade20k.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_ade20k.txt'
)

# dataset settings
dataset_type = 'ADE20KDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation'),
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_city_scapes.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_city_scapes.txt'
)

# dataset settings
dataset_type = 'CityscapesDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 560), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_coco_object.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_coco_object.txt',
    logit_scale=50,
    prob_thd=0.1
)

# dataset settings
dataset_type = 'COCOObjectDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        reduce_zero_label=False,
        data_prefix=dict(
            img_path='images/val2017', seg_map_path='annotations/val2017'),
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_coco_stuff10k.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_coco_stuff.txt'
)

# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        reduce_zero_label=True,
        data_prefix=dict(
            img_path='images/test2014', seg_map_path='annotations/test2014'),
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_coco_stuff164k.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_coco_stuff.txt'
)

# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 448), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/val2017', seg_map_path='annotations/val2017'),
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_context59.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_context59.txt'
)

# dataset settings
dataset_type = 'PascalContext59Dataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
        ann_file='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_context60.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_context60.txt',
    logit_scale=50,
    prob_thd=0.1
)

# dataset settings
dataset_type = 'PascalContext60Dataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
        ann_file='ImageSets/SegmentationContext/val.txt',
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_voc20.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_voc20.txt'
)

# dataset settings
dataset_type = 'PascalVOC20Dataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='JPEGImages', seg_map_path='SegmentationClass'),
        ann_file='ImageSets/Segmentation/val.txt',
        pipeline=test_pipeline))

================================================
FILE: configs/cfg_voc21.py
================================================
_base_ = './base_config.py'

# model settings
model = dict(
    name_path='./configs/cls_voc21.txt',
    logit_scale=65,
    prob_thd=0.1,
    area_thd=0.1
)

# dataset settings
dataset_type = 'PascalVOCDataset'
data_root = ''

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 336), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='JPEGImages', seg_map_path='SegmentationClass'),
        ann_file='ImageSets/Segmentation/val.txt',
        pipeline=test_pipeline))

================================================
FILE: configs/cls_ade20k.txt
================================================
wall
building
sky
floor
tree
ceiling
road
bed
windowpane
grass
cabinet
sidewalk
person
earth
door
table
mountain
plant
curtain
chair
car
water
painting
sofa
shelf
house
sea
mirror
rug
field
armchair
seat
fence
desk
rock
wardrobe
lamp
bathtub
railing
cushion
base
box
column
signboard
chestofdrawers
counter
sand
sink
skyscraper
fireplace
refrigerator
grandstand
path
stairs
runway
case
pooltable
pillow
screendoor
stairway
river
bridge
bookcase
blind
coffeetable
toilet
flower
book
hill
bench
countertop
stove
palm
kitchenisland
computer
swivelchair
boat
bar
arcademachine
hovel
bus
towel
light
truck
tower
chandelier
awning
streetlight
booth
televisionreceiver
airplane
dirttrack
apparel
pole
land
bannister
escalator
ottoman
bottle
buffet
poster
stage
van
ship
fountain
conveyerbelt
canopy
washer
plaything
swimmingpool
stool
barrel
basket
waterfall
tent
bag
minibike
cradle
oven
ball
food
step
tank
tradename
microwave
pot
animal
bicycle
lake
dishwasher
screen
blanket
sculpture
hood
sconce
vase
trafficlight
tray
ashcan
fan
pier
crtscreen
plate
monitor
bulletinboard
shower
radiator
glass
clock
flag

================================================
FILE: configs/cls_city_scapes.txt
================================================
road
sidewalk
building
wall
fence
pole
trafficlight
trafficsign
vegetation
terrain
sky
person
rider
car
truck
bus
train
motorcycle
bicycle

================================================
FILE: configs/cls_coco_object.txt
================================================
sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

================================================
FILE: configs/cls_coco_stuff.txt
================================================
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
trafficlight
firehydrant
stopsign
parkingmeter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sportsball
kite
baseballbat
baseballglove
skateboard
surfboard
tennisracket
bottle
wineglass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hotdog
pizza
donut
cake
chair
couch
pottedplant
bed
diningtable
toilet
tv
laptop
mouse
remote
keyboard
cellphone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddybear
hairdrier
toothbrush
banner
blanket
branch
bridge
building-other
bush
cabinet
cage
cardboard
carpet
ceiling-other
ceiling-tile
cloth
clothes
clouds
counter
cupboard
curtain
desk-stuff
dirt
door-stuff
fence
floor-marble
floor-other
floor-stone
floor-tile
floor-wood
flower
fog
food-other
fruit
furniture-other
grass
gravel
ground-other
hill
house
leaves
light
mat
metal
mirror-stuff
moss
mountain
mud
napkin
net
paper
pavement
pillow
plant-other
plastic
platform
playingfield
railing
railroad
river
road
rock
roof
rug
salad
sand
sea
shelf
sky-other
skyscraper
snow
solid-other
stairs
stone
straw
structural-other
table
tent
textile-other
towel
tree
vegetable
wall-brick
wall-concrete
wall-other
wall-panel
wall-stone
wall-tile
wall-wood
water-other
waterdrops
window-blind
window-other
wood

================================================
FILE: configs/cls_context59.txt
================================================
aeroplane
bag
bed
bedclothes
bench
bicycle
bird
boat
book
bottle
building
bus
cabinet
car
cat
ceiling
chair
cloth
computer
cow
cup
curtain
dog
door
fence
floor
flower
food
grass
ground
horse
keyboard
light
motorbike
mountain
mouse
person
plate
platform
pottedplant
road
rock
sheep
shelves
sidewalk
sign
sky
snow
sofa
table
track
train
tree
truck
tvmonitor
wall
water
window
wood

================================================
FILE: configs/cls_context60.txt
================================================
background
aeroplane
bag
bed
bedclothes
bench
bicycle
bird
boat
book
bottle
building
bus
cabinet
car
cat
ceiling
chair
cloth
computer
cow
cup
curtain
dog
door
fence
floor
flower
food
grass
ground
horse
keyboard
light
motorbike
mountain
mouse
person
plate
platform
pottedplant
road
rock
sheep
shelves
sidewalk
sign
sky
snow
sofa
table
track
train
tree
truck
tvmonitor
wall
water
window
wood

================================================
FILE: configs/cls_voc20.txt
================================================
aeroplane
bicycle
bird
ship
bottle
bus
car
cat
chair
cow
table
dog
horse
motorbike
person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
pottedplant
sheep
sofa
train
television monitor, tv monitor, monitor, television, screen

================================================
FILE: configs/cls_voc21.txt
================================================
sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
aeroplane
bicycle
bird
ship
bottle
bus
car
cat
chair
cow
table
dog
horse
motorbike
person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
pottedplant
sheep
sofa
train
television monitor, tv monitor, monitor, television, screen

================================================
FILE: custom_datasets.py
================================================
import os.path as osp

import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset


@DATASETS.register_module()
class PascalVOC20Dataset(BaseSegDataset):
    """Pascal VOC dataset without the background class (20 categories).

    Args:
        ann_file (str): Split txt file for Pascal VOC.
    """
    METAINFO = dict(
        classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
                 'train', 'tvmonitor'),
        palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192], [128, 128, 0],
                 [128, 0, 128], [0, 128, 128], [192, 128, 64], [64, 0, 0],
                 [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
                 [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
                 [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]])

    def __init__(self,
                 ann_file,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            ann_file=ann_file,
            **kwargs)
        assert fileio.exists(self.data_prefix['img_path'],
                             self.backend_args) and osp.isfile(self.ann_file)


@DATASETS.register_module()
class COCOObjectDataset(BaseSegDataset):
    """
    Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT)
    COCO-Object dataset.
    1 bg class + first 80 classes from the COCO-Stuff dataset.
    """

    METAINFO = dict(
        classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                   'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
                   'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                   'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
                   'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
                   'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                   'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
                   'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                   'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table',
                   'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
                   'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
                   'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),

        palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
                   [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], [0, 192, 96],
                   [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64], [0, 160, 192],
                   [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192],
                   [128, 160, 192], [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
                   [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
                   [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
                   [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0],
                   [192, 0, 160], [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
                   [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
                   [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], [0, 128, 192],
                   [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128],
                   [0, 64, 192], [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
                   [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
                   [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
                   [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128],
                   [64, 128, 224], [192, 160, 0]])

    def __init__(self, **kwargs):
        super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)


@DATASETS.register_module()
class PascalContext60Dataset(BaseSegDataset):
    METAINFO = dict(
        classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
                 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
                 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
                 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
                 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
                 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
                 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
                 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
                 'sofa', 'table', 'track', 'train', 'tree', 'truck',
                 'tvmonitor', 'wall', 'water', 'window', 'wood'),
        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                 [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
                 [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
                 [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
                 [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
                 [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
                 [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
                 [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
                 [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
                 [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
                 [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
                 [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
                 [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])

    def __init__(self,
                 ann_file: str,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            reduce_zero_label=False,
            **kwargs)


@DATASETS.register_module()
class PascalContext59Dataset(BaseSegDataset):
    METAINFO = dict(
        classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
                 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
                 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
                 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
                 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
                 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
                 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
                 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
                 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
                 'wall', 'water', 'window', 'wood'),
        palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
                 [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
                 [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
                 [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
                 [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
                 [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
                 [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
                 [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
                 [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
                 [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
                 [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
                 [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
                 [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
                 [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
                 [0, 235, 255], [0, 173, 255], [31, 0, 255]])

    def __init__(self,
                 ann_file: str,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs):
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            reduce_zero_label=reduce_zero_label,
            **kwargs)

================================================
FILE: datasets/cvt_coco_object.py
================================================
# ------------------------------------------------------------------------------
# GroupViT (https://github.com/NVlabs/GroupViT)
# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
# ------------------------------------------------------------------------------

import argparse
import os.path as osp
import shutil
from functools import partial
from glob import glob

import mmcv
import numpy as np
from PIL import Image

COCO_LEN = 123287

clsID_to_trID = {
    0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8,
    9: 9, 10: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17,
    19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 26: 24, 27: 25, 30: 26,
    31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35,
    40: 36, 41: 37, 42: 38, 43: 39, 45: 40, 46: 41, 47: 42, 48: 43, 49: 44,
    50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53,
    59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 66: 60, 69: 61, 71: 62,
    72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71,
    81: 72, 83: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 91: 80,
    92: 81, 93: 82, 94: 83, 95: 84, 96: 85, 97: 86, 98: 87, 99: 88, 100: 89,
    101: 90, 102: 91, 103: 92, 104: 93, 105: 94, 106: 95, 107: 96, 108: 97, 109: 98,
    110: 99, 111: 100, 112: 101, 113: 102, 114: 103, 115: 104, 116: 105, 117: 106, 118: 107,
    119: 108, 120: 109, 121: 110, 122: 111, 123: 112, 124: 113, 125: 114, 126: 115, 127: 116,
    128: 117, 129: 118, 130: 119, 131: 120, 132: 121, 133: 122, 134: 123, 135: 124, 136: 125,
    137: 126, 138: 127, 139: 128, 140: 129, 141: 130, 142: 131, 143: 132, 144: 133, 145: 134,
    146: 135, 147: 136, 148: 137, 149: 138, 150: 139, 151: 140, 152: 141, 153: 142, 154: 143,
    155: 144, 156: 145, 157: 146, 158: 147, 159: 148, 160: 149, 161: 150, 162: 151, 163: 152,
    164: 153, 165: 154, 166: 155, 167: 156, 168: 157, 169: 158, 170: 159, 171: 160, 172: 161,
    173: 162, 174: 163, 175: 164, 176: 165, 177: 166, 178: 167, 179: 168, 180: 169, 181: 170,
    255: 255
}
# Shift every class index by +1 so that 0 can serve as background, then map
# all stuff classes (id > 90) and the ignore label 255 to background (0)
for k, v in clsID_to_trID.items():
    clsID_to_trID[k] = v + 1
    if k > 90:
        clsID_to_trID[k] = 0


def convert_to_trainID(maskpath, out_mask_dir, is_train):
    mask = np.array(Image.open(maskpath))
    mask_copy = mask.copy()
    for clsID, trID in clsID_to_trID.items():
        mask_copy[mask == clsID] = trID
    seg_filename = osp.join(
        out_mask_dir, 'train2017',
        osp.basename(maskpath).split('.')[0] +
        '_instanceTrainIds.png') if is_train else osp.join(
            out_mask_dir, 'val2017',
            osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
    Image.fromarray(mask_copy).save(seg_filename, 'PNG')


def parse_args():
    parser = argparse.ArgumentParser(
        description=\
        'Convert COCO Stuff 164k annotations to COCO Objects')  # noqa
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    coco_path = args.coco_path
    nproc = args.nproc

    out_dir = args.out_dir or coco_path
    out_img_dir = osp.join(out_dir, 'images')
    out_mask_dir = osp.join(out_dir, 'annotations')

    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
    mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))

    if out_dir != coco_path:
        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)

    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
    train_list = [file for file in train_list if 'TrainIds' not in file]
    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
    test_list = [file for file in test_list if 'TrainIds' not in file]
    assert (len(train_list) +
            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
                len(train_list), len(test_list))

    if args.nproc > 1:
        mmcv.track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
            train_list,
            nproc=nproc)
        mmcv.track_parallel_progress(
            partial(
                convert_to_trainID, out_mask_dir=out_mask_dir,
is_train=False), test_list, nproc=nproc) else: mmcv.track_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), train_list) mmcv.track_progress( partial( convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), test_list) print('Done!') if __name__ == '__main__': main() ================================================ FILE: dist_test.sh ================================================ CONFIG=$1 WORK_DIR=${WORK_DIR:-"./work_logs"} GPUS=${GPUS:-4} NNODES=${NNODES:-1} NODE_RANK=${NODE_RANK:-0} PORT=${PORT:-29500} MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ python -m torch.distributed.launch \ --nnodes=$NNODES \ --node_rank=$NODE_RANK \ --master_addr=$MASTER_ADDR \ --nproc_per_node=$GPUS \ --master_port=$PORT \ $(dirname "$0")/eval.py \ --config $CONFIG \ --work-dir $WORK_DIR \ --launcher pytorch \ ${@:4} ================================================ FILE: eval.py ================================================ import os import argparse import clip_segmentor import custom_datasets from mmengine.config import Config from mmengine.runner import Runner def parse_args(): parser = argparse.ArgumentParser( description='SCLIP evaluation with MMSeg') parser.add_argument('--config', default='') parser.add_argument('--work-dir', default='./work_logs/') parser.add_argument( '--show', action='store_true', help='show prediction results') parser.add_argument( '--show_dir', default='', help='directory to save visualizaion images') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` # will pass the `--local-rank` parameter to `tools/train.py` instead # of `--local_rank`. 
parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) return args def trigger_visualization_hook(cfg, args): default_hooks = cfg.default_hooks if 'visualization' in default_hooks: visualization_hook = default_hooks['visualization'] # Turn on visualization visualization_hook['draw'] = True if args.show: visualization_hook['show'] = True visualization_hook['wait_time'] = args.wait_time if args.show_dir: visualizer = cfg.visualizer visualizer['save_dir'] = args.show_dir else: raise RuntimeError( 'VisualizationHook must be included in default_hooks.' 'refer to usage ' '"visualization=dict(type=\'VisualizationHook\')"') return cfg def main(): args = parse_args() cfg = Config.fromfile(args.config) cfg.launcher = args.launcher cfg.work_dir = args.work_dir runner = Runner.from_cfg(cfg) runner.test() if __name__ == '__main__': main() ================================================ FILE: pamr.py ================================================ # Copyright 2020 TU Darmstadt # Licnese: Apache 2.0 License. 
# https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
import torch
import torch.nn.functional as F
import torch.nn as nn

from functools import partial

#
# Helper modules
#
class LocalAffinity(nn.Module):

    def __init__(self, dilations=[1]):
        super(LocalAffinity, self).__init__()
        self.dilations = dilations
        weight = self._init_aff()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        # initialising the shift kernel
        weight = torch.zeros(8, 1, 3, 3)

        for i in range(weight.size(0)):
            weight[i, 0, 1, 1] = 1

        weight[0, 0, 0, 0] = -1
        weight[1, 0, 0, 1] = -1
        weight[2, 0, 0, 2] = -1

        weight[3, 0, 1, 0] = -1
        weight[4, 0, 1, 2] = -1

        weight[5, 0, 2, 0] = -1
        weight[6, 0, 2, 1] = -1
        weight[7, 0, 2, 2] = -1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):

        self.weight_check = self.weight_check.type_as(x)
        assert torch.all(self.weight_check.eq(self.kernel))

        B,K,H,W = x.size()
        x = x.view(B*K,1,H,W)

        x_affs = []
        for d in self.dilations:
            x_pad = F.pad(x, [d]*4, mode='replicate')
            x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
            x_affs.append(x_aff)

        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B,K,-1,H,W)


class LocalAffinityCopy(LocalAffinity):

    def _init_aff(self):
        # initialising the shift kernel
        weight = torch.zeros(8, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 2] = 1

        weight[5, 0, 2, 0] = 1
        weight[6, 0, 2, 1] = 1
        weight[7, 0, 2, 2] = 1

        self.weight_check = weight.clone()
        return weight


class LocalStDev(LocalAffinity):

    def _init_aff(self):
        weight = torch.zeros(9, 1, 3, 3)
        weight.zero_()

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 1] = 1
        weight[5, 0, 1, 2] = 1

        weight[6, 0, 2, 0] = 1
        weight[7, 0, 2, 1] = 1
        weight[8, 0, 2, 2] = 1

        self.weight_check = weight.clone()
        return weight

    def forward(self, x):
        # the parent forward returns (B,K,P,H,W), where P is the number
        # of sampled locations; the std over P gives (B,K,1,H,W)
        x = super(LocalStDev, self).forward(x)

        return x.std(2, keepdim=True)


class LocalAffinityAbs(LocalAffinity):

    def forward(self, x):
        x = super(LocalAffinityAbs, self).forward(x)
        return torch.abs(x)

#
# PAMR module
#
class PAMR(nn.Module):

    def __init__(self, num_iter=1, dilations=[1]):
        super(PAMR, self).__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        # x: [BxKxHxW]
        # mask: [BxCxHxW]
        B,K,H,W = x.size()
        _,C,_,_ = mask.size()

        x_std = self.aff_std(x)

        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        x = x.mean(1, keepdim=True)
        x = F.softmax(x, 2)

        for _ in range(self.num_iter):
            m = self.aff_m(mask)  # [BxCxPxHxW]
            mask = (m * x).sum(2)

        # mask: [BxCxHxW]
        return mask

================================================
FILE: prompts/imagenet_template.py
================================================
imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor",
"African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane",
"St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey",
"common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", 
"monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", 
"sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]

openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
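Templates like `openai_imagenet_template` are commonly used for CLIP-style prompt ensembling: each class name is substituted into every template, the resulting text embeddings are L2-normalized and averaged per category, and dense features are then labeled by their most similar category embedding. The NumPy sketch below illustrates that idea under stated assumptions. `fake_encode_text` is a hypothetical stand-in for a real CLIP text encoder, the class-name lines are made-up examples in the comma-separated format the README describes, and the helper names are illustrative. It is not claimed to reproduce how `clip_segmentor.py` builds its classifier; in particular, whether synonyms on one line are averaged (as here) or scored separately and combined by a max is not shown in this part of the dump.

```python
import numpy as np

# Subset of templates in the style of `openai_imagenet_template`.
templates = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'itap of my {c}.',
]


def parse_class_lines(lines):
    """One category per line; synonyms on the same line are comma-separated."""
    return [[name.strip() for name in line.split(',') if name.strip()]
            for line in lines if line.strip()]


def fake_encode_text(prompts, dim=8):
    """Hypothetical stand-in for a CLIP text encoder (one vector per prompt)."""
    vecs = []
    for p in prompts:
        rng = np.random.default_rng(sum(p.encode()))  # deterministic per prompt
        vecs.append(rng.standard_normal(dim))
    return np.stack(vecs)


def build_class_embeddings(categories, templates, encode_text):
    """Prompt ensembling: normalize each prompt embedding, average per category, renormalize."""
    rows = []
    for names in categories:
        prompts = [t(name) for name in names for t in templates]
        emb = encode_text(prompts)
        emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
        mean = emb.mean(axis=0)
        rows.append(mean / np.linalg.norm(mean))
    return np.stack(rows)  # [num_categories, dim]


# Made-up class-name lines in the comma-separated format described in the README.
categories = parse_class_lines(['background', 'aeroplane', 'person, person in shirt'])
text_emb = build_class_embeddings(categories, templates, fake_encode_text)

# Hypothetical dense features: 4 unit-norm patch embeddings of the same dimension.
patch = np.random.default_rng(0).standard_normal((4, 8))
patch = patch / np.linalg.norm(patch, axis=-1, keepdims=True)
pred = (patch @ text_emb.T).argmax(axis=1)  # one category index per patch
print(text_emb.shape, pred.shape)  # (3, 8) (4,)
```

Averaging after per-prompt normalization keeps any single template from dominating, and the final renormalization makes the dot product with unit-norm patch features a cosine similarity.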
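`convert_to_trainID` in `datasets/cvt_coco_object.py` rewrites each mask with one boolean comparison per entry of `clsID_to_trID`, so every image is scanned once per dictionary entry. A single lookup-table gather gives the same result. The sketch below is a minimal illustration, not the script's code. `remap_loop`, `remap_lut` and the five-entry `id_map` are hypothetical names; the entries are taken from the mapping after the +1 shift and the stuff-to-background step. It also assumes 8-bit masks (values 0 to 255). The table starts as the identity so that IDs missing from the mapping pass through unchanged, just as they do in the loop.

```python
import numpy as np

# Subset of the post-shift mapping: thing IDs 0, 1, 12 -> 1, 2, 12;
# stuff ID 95 and the 255 ignore value -> background (0).
id_map = {0: 1, 1: 2, 12: 12, 95: 0, 255: 0}


def remap_loop(mask, id_map):
    """Per-ID loop in the style of convert_to_trainID; unmapped values stay as-is."""
    out = mask.copy()
    for src, dst in id_map.items():
        out[mask == src] = dst  # compare against the original mask, never `out`
    return out


def remap_lut(mask, id_map):
    """Same result with one gather through a 256-entry lookup table."""
    lut = np.arange(256, dtype=np.uint8)  # identity keeps unmapped values
    for src, dst in id_map.items():
        lut[src] = dst
    return lut[mask]


mask = np.array([[0, 1, 95], [255, 12, 11]], dtype=np.uint8)  # 11 is not in id_map
print(remap_lut(mask, id_map).tolist())  # [[1, 2, 0], [0, 12, 11]]
```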