[
  {
    "path": "README.md",
    "content": "# SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference\n\n**News: this paper has been accepted by ECCV 2024**\n\n**Official PyTorch implementation of SCLIP**\n\n* [SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference](https://arxiv.org/pdf/2312.01597.pdf).\n* A **simple** but very effective open-vocabulary semantic segmentation model derived from CLIP.\n* **SOTA** zero-shot segmentation results obtained by minimal modifications to CLIP's self-attention.\n\n**Model components and our Correlative Self-Attention maps:**\n\n![sclip_0](figs/sclip_0.png)\n\n**Open-vocabulary semantic segmentation samples:**\n\n![sclip_1](figs/sclip_1.png)\n\n\n\n## Dependencies\n\nThis repo is built on top of [CLIP](https://github.com/openai/CLIP) and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). To run SCLIP, please install the following packages in your PyTorch environment. We recommend PyTorch 1.10.x for better compatibility with the MMSeg version below.\n\n```\npip install openmim\nmim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1\npip install ftfy regex yapf==0.40.1\n```\n\n\n\n## Datasets\nWe include the following dataset configurations in this repo: PASCAL VOC, PASCAL Context, Cityscapes, ADE20k, COCO-Stuff10k, and COCO-Stuff164k, plus three variant datasets: VOC20 and Context59 (i.e., PASCAL VOC and PASCAL Context without the background category), and COCO-Object.\n\nPlease follow the [MMSeg data preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download and pre-process the datasets. 
The COCO-Object dataset can be converted from COCO-Stuff164k by executing the following command:\n\n```\npython datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO164K\n```\n\n**Remember to modify the dataset paths in the config files** `configs/cfg_DATASET.py`.\n\n\n\n## Run SCLIP\nSingle-GPU evaluation:\n\n```\npython eval.py --config ./configs/cfg_DATASET.py --workdir YOUR_WORK_DIR\n```\n\nMulti-GPU evaluation:\n```\nbash ./dist_test.sh ./configs/cfg_DATASET.py\n```\n\n\n\n## Results\n\nThe performance of open-vocabulary inference can be affected by the text targets, i.e., the prompts and class names. This repo provides an easy way to explore them: you can modify the prompts in `prompts/imagenet_template.py` and the class names in `configs/cls_DATASET.text`.\n\nThe repo automatically loads class names from the `configs/cls_DATASET.text` file. Each line of the file corresponds to one category; a category can have multiple class names, which share that line and are separated by commas.\n\nWith the default setup in this repo, you should get the following results:\n\n| Dataset               | mIoU (%) |\n| --------------------- | -------- |\n| ADE20k                | 16.45    |\n| Cityscapes            | 32.34    |\n| COCO-Object           | 33.52    |\n| COCO-Stuff10k         | 25.91    |\n| COCO-Stuff164k        | 22.77    |\n| PASCAL Context59      | 34.46    |\n| PASCAL Context60      | 31.74    |\n| PASCAL VOC (w/o. bg.) | 81.54    |\n| PASCAL VOC (w. bg.)   | 59.63    |\n\n\n\n## Citation\n\n```\n@article{wang2023sclip,\n  title={SCLIP: Rethinking Self-Attention for Dense Vision-Language Inference},\n  author={Wang, Feng and Mei, Jieru and Yuille, Alan},\n  journal={arXiv preprint arXiv:2312.01597},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "clip/__init__.py",
    "content": "from .clip import *\nfrom .model import *\n"
  },
  {
    "path": "clip/clip.py",
    "content": "### CLIP source code from OpenAI:\n# https://github.com/openai/CLIP/blob/main/clip/clip.py\n\nimport hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Any, Union, List\nfrom pkg_resources import packaging\n\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\nfrom tqdm import tqdm\n\nfrom .model import build_model\nfrom .simple_tokenizer import SimpleTokenizer as _Tokenizer\n\ntry:\n    from torchvision.transforms import InterpolationMode\n    BICUBIC = InterpolationMode.BICUBIC\nexcept ImportError:\n    BICUBIC = Image.BICUBIC\n\n\nif packaging.version.parse(torch.__version__) < packaging.version.parse(\"1.7.1\"):\n    warnings.warn(\"PyTorch version 1.7.1 or higher is recommended\")\n\n\n__all__ = [\"available_models\", \"load\", \"tokenize\"]\n_tokenizer = _Tokenizer()\n\n_MODELS = {\n    \"RN50\": \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n    \"RN101\": \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n    \"RN50x4\": \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n    \"RN50x16\": \"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n    \"ViT-B/32\": \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n    \"ViT-B/16\": \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n    \"ViT-L/14\": \"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\",\n    \"ViT-L/14@336px\": 
\"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\",\n}\n\n\ndef _download(url: str, root: str):\n    os.makedirs(root, exist_ok=True)\n    filename = os.path.basename(url)\n\n    expected_sha256 = url.split(\"/\")[-2]\n    download_target = os.path.join(root, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest() == expected_sha256:\n            return download_target\n        else:\n            warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n        with tqdm(total=int(source.info().get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest() != expected_sha256:\n        raise RuntimeError(\"Model has been downloaded but the SHA256 checksum does not match\")\n\n    return download_target\n\n\ndef _convert_image_to_rgb(image):\n    return image.convert(\"RGB\")\n\n\ndef _transform(n_px):\n    return Compose([\n        Resize(n_px, interpolation=BICUBIC),\n        CenterCrop(n_px),\n        _convert_image_to_rgb,\n        ToTensor(),\n        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n    ])\n\n\ndef available_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list(_MODELS.keys())\n\n\ndef 
load(name: str, device: Union[str, torch.device] = \"cuda\" if torch.cuda.is_available() else \"cpu\", jit: bool = False, download_root: str = None):\n    \"\"\"Load a CLIP model\n\n    Parameters\n    ----------\n    name : str\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\n\n    device : Union[str, torch.device]\n        The device to put the loaded model\n\n    jit : bool\n        Whether to load the optimized JIT model or more hackable non-JIT model (default).\n\n    download_root: str\n        path to download the model files; by default, it uses \"~/.cache/clip\"\n\n    Returns\n    -------\n    model : torch.nn.Module\n        The CLIP model\n\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if name in _MODELS:\n        model_path = _download(_MODELS[name], download_root or os.path.expanduser(\"~/.cache/clip\"))\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(f\"Model {name} not found; available models = {available_models()}\")\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=device if jit else \"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state dict\n        if jit:\n            warnings.warn(f\"File {model_path} is not a JIT archive. 
Loading as a state dict instead\")\n            jit = False\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    if not jit:\n        model = build_model(state_dict or model.state_dict()).to(device)\n        if str(device) == \"cpu\":\n            model.float()\n        return model, _transform(model.visual.input_resolution)\n\n    # patch the device names\n    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])\n    device_node = [n for n in device_holder.graph.findAllNodes(\"prim::Constant\") if \"Device\" in repr(n)][-1]\n\n    def patch_device(module):\n        try:\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\n        except RuntimeError:\n            graphs = []\n\n        if hasattr(module, \"forward1\"):\n            graphs.append(module.forward1.graph)\n\n        for graph in graphs:\n            for node in graph.findAllNodes(\"prim::Constant\"):\n                if \"value\" in node.attributeNames() and str(node[\"value\"]).startswith(\"cuda\"):\n                    node.copyAttributes(device_node)\n\n    model.apply(patch_device)\n    patch_device(model.encode_image)\n    patch_device(model.encode_text)\n\n    # patch dtype to float32 on CPU\n    if str(device) == \"cpu\":\n        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\n        float_node = float_input.node()\n\n        def patch_float(module):\n            try:\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\n            except RuntimeError:\n                graphs = []\n\n            if hasattr(module, \"forward1\"):\n                graphs.append(module.forward1.graph)\n\n            for graph in graphs:\n                for node in graph.findAllNodes(\"aten::to\"):\n                    inputs = list(node.inputs())\n                    for i in [1, 
2]:  # dtype can be the second or third argument to aten::to()\n                        if inputs[i].node()[\"value\"] == 5:\n                            inputs[i].node().copyAttributes(float_node)\n\n        model.apply(patch_float)\n        patch_float(model.encode_image)\n        patch_float(model.encode_text)\n\n        model.float()\n\n    return model, _transform(model.input_resolution.item())\n\n\ndef tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:\n    \"\"\"\n    Returns the tokenized representation of given input string(s)\n\n    Parameters\n    ----------\n    texts : Union[str, List[str]]\n        An input string or a list of input strings to tokenize\n\n    context_length : int\n        The context length to use; all CLIP models use 77 as the context length\n\n    truncate: bool\n        Whether to truncate the text in case its encoding is longer than the context length\n\n    Returns\n    -------\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n    \"\"\"\n    if isinstance(texts, str):\n        texts = [texts]\n\n    sot_token = _tokenizer.encoder[\"<|startoftext|>\"]\n    eot_token = _tokenizer.encoder[\"<|endoftext|>\"]\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        if len(tokens) > context_length:\n            if truncate:\n                tokens = tokens[:context_length]\n                tokens[-1] = eot_token\n            else:\n                raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result"
  },
  {
    "path": "clip/model.py",
    "content": "### Adapted from OpenAI's CLIP source code:\n# https://github.com/openai/CLIP/blob/main/clip/model.py\n\nfrom collections import OrderedDict\nfrom typing import Tuple, Union\nimport math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nimport torchvision.transforms.functional as VF\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(OrderedDict([\n                (\"-1\", nn.AvgPool2d(stride)),\n                (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n                (\"1\", nn.BatchNorm2d(planes * self.expansion))\n            ]))\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.relu(self.bn1(self.conv1(x)))\n        out = self.relu(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n        
return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x, return_all_tokens=False):\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x, key=x, value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False\n        )\n        if return_all_tokens:\n            return x\n        else:\n            return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs 
anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n\n    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.input_resolution = input_resolution\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.avgpool = nn.AvgPool2d(2)\n        self.relu = nn.ReLU(inplace=True)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32  # the ResNet feature dimension\n        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x, return_all_tokens=False):\n        def stem(x):\n            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:\n                x = 
self.relu(bn(conv(x)))\n            x = self.avgpool(x)\n            return x\n\n        x = x.type(self.conv1.weight.dtype)\n        x = stem(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.attnpool(x, return_all_tokens)\n\n        return x\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, d_model * 4)),\n            (\"gelu\", QuickGELU()),\n            (\"c_proj\", nn.Linear(d_model * 4, d_model))\n        ]))\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n\n    def attention(self, x: torch.Tensor):\n        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None\n        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attention(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in 
range(layers)])\n\n    def forward(self, x: torch.Tensor):\n        return self.resblocks(x)\n    \n\nclass VisionTransformer(nn.Module):\n    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.patch_size = patch_size\n        self.output_dim = output_dim\n        self.width = width\n        self.heads = heads\n        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)\n\n        scale = width ** -0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))\n        self.ln_pre = LayerNorm(width)\n\n        self.transformer = Transformer(width, layers, heads)\n\n        self.ln_post = LayerNorm(width)\n        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\n    \n    def forward(self, x: torch.Tensor, return_all=False, csa=True):\n\n        B, nc, w, h = x.shape\n\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n\n        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]\n        \n        if x.shape[1] != self.positional_embedding.shape[0]:\n            x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)\n        else:\n            x = x + self.positional_embedding.to(x.dtype)\n\n        x = self.ln_pre(x)           \n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        for blk in self.transformer.resblocks[:-1]:\n            x = blk(x)\n        for blk in self.transformer.resblocks[-1:]:\n            x = x + 
self.custom_attn(blk.attn, blk.ln_1(x), csa=csa)\n            x = x + blk.mlp(blk.ln_2(x))\n        x = x.permute(1, 0, 2)  # LND -> NLD\n            \n        if return_all:\n            return self.ln_post(x) @ self.proj\n\n        x = self.ln_post(x[:, 0, :])\n        if self.proj is not None:\n            x = x @ self.proj\n\n        return x\n    \n    def interpolate_pos_encoding(self, x, w, h):\n        npatch = x.shape[1] - 1\n        N = self.positional_embedding.shape[0] - 1\n        if npatch == N and w == h:\n            return self.positional_embedding\n        class_pos_embed = self.positional_embedding[[0]]\n        patch_pos_embed = self.positional_embedding[1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        w0, h0 = w0 + 0.1, h0 + 0.1\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n            mode='bicubic',\n        )\n        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n    \n    def custom_attn(self, attn_layer, x, return_attn=False, with_attn=False, csa=False):\n        \n        num_heads = attn_layer.num_heads\n        _, bsz, embed_dim = x.size()\n        head_dim = embed_dim // num_heads\n        scale = head_dim ** -0.5\n\n        q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1)\n        q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n\n        if csa:\n            q_attn = torch.bmm(q, 
q.transpose(1, 2)) * scale\n            k_attn = torch.bmm(k, k.transpose(1, 2)) * scale\n            attn_weights = F.softmax(q_attn, dim=-1) + F.softmax(k_attn, dim=-1)\n        else:\n            attn_weights = torch.bmm(q * scale, k.transpose(1, 2))\n            attn_weights = F.softmax(attn_weights, dim=-1)\n\n        if return_attn:\n            return attn_weights\n\n        attn_output = torch.bmm(attn_weights, v)\n        attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim)\n        attn_output = attn_layer.out_proj(attn_output)\n\n        if with_attn:\n            return attn_output, attn_weights\n\n        return attn_output\n    \n    def get_attn(self, x, layer='all', csa=False):\n\n        B, nc, w, h = x.shape\n\n        x = self.conv1(x.type(self.conv1.weight.dtype))  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n\n        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]\n        \n        if x.shape[1] != self.positional_embedding.shape[0]:\n            x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)\n        else:\n            x = x + self.positional_embedding.to(x.dtype)\n\n        x = self.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n\n        if layer == 'final':\n            for blk in self.transformer.resblocks[:-1]:\n                x = blk(x)\n            attn_map = self.custom_attn(self.transformer.resblocks[-1].attn,\n                                        self.transformer.resblocks[-1].ln_1(x),\n                                        csa=csa, return_attn=True)\n            return attn_map\n        elif layer == 'all':\n            attn_map = []\n            for blk in self.transformer.resblocks[:-1]:\n                x_i, attn_i 
= self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True)\n                x = x + x_i\n                x = x + blk.mlp(blk.ln_2(x))\n                attn_map.append(attn_i)\n            for blk in self.transformer.resblocks[-1:]:\n                x_i, attn_i = self.custom_attn(blk.attn, blk.ln_1(x), with_attn=True, csa=True)\n                x = x + x_i\n                x = x + blk.mlp(blk.ln_2(x))\n                attn_map.append(attn_i)\n            return attn_map\n        else:\n            raise ValueError('layer should be final or all')\n\n\nclass CLIP(nn.Module):\n    def __init__(self,\n                 embed_dim: int, # 512\n                 # vision\n                 image_resolution: int, # 224\n                 vision_layers: Union[Tuple[int, int, int, int], int], # 12\n                 vision_width: int, # 768\n                 vision_patch_size: int, # 16\n                 # text\n                 context_length: int, # 77\n                 vocab_size: int, # 49408\n                 transformer_width: int, # 512\n                 transformer_heads: int, # 8\n                 transformer_layers: int # 12\n                 ):\n        super().__init__()\n        self.context_length = context_length\n\n        if isinstance(vision_layers, (tuple, list)):\n            vision_heads = vision_width * 32 // 64\n            self.visual = ModifiedResNet(\n                layers=vision_layers,\n                output_dim=embed_dim,\n                heads=vision_heads,\n                input_resolution=image_resolution,\n                width=vision_width\n            )\n        else:\n            vision_heads = vision_width // 64\n            self.visual = VisionTransformer(\n                input_resolution=image_resolution,\n                patch_size=vision_patch_size,\n                width=vision_width,\n                layers=vision_layers,\n                heads=vision_heads,\n                output_dim=embed_dim\n            )\n\n        
self.transformer = Transformer(\n            width=transformer_width,\n            layers=transformer_layers,\n            heads=transformer_heads,\n            attn_mask=self.build_attention_mask()\n        )\n\n        self.vocab_size = vocab_size\n        self.token_embedding = nn.Embedding(vocab_size, transformer_width)\n        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))\n        self.ln_final = LayerNorm(transformer_width)\n\n        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n        self.initialize_parameters()\n\n    def initialize_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n\n        if isinstance(self.visual, ModifiedResNet):\n            if self.visual.attnpool is not None:\n                std = self.visual.attnpool.c_proj.in_features ** -0.5\n                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)\n\n            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:\n                for name, param in resnet_block.named_parameters():\n                    if name.endswith(\"bn3.weight\"):\n                        nn.init.zeros_(param)\n\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        attn_std = self.transformer.width ** -0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    @property\n    def dtype(self):\n        return self.visual.conv1.weight.dtype\n\n    def encode_image(self, image, return_all=False, csa=False):\n        return self.visual(image.type(self.dtype), return_all=return_all, csa=csa)\n\n    def encode_text(self, text):\n        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x).type(self.dtype)\n\n        return x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\n\n    def forward(self, image, text):\n        image_features = self.encode_image(image)\n        text_features = self.encode_text(text)\n\n        # normalized features\n        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp()\n        logits_per_image = logit_scale * image_features @ text_features.t()\n        logits_per_text = logits_per_image.t()\n\n        # shape = [global_batch_size, 
global_batch_size]\n        return logits_per_image, logits_per_text\n\ndef convert_weights(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n        if isinstance(l, nn.MultiheadAttention):\n            for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n                tensor = getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.half()\n\n        for name in [\"text_projection\", \"proj\"]:\n            if hasattr(l, name):\n                attr = getattr(l, name)\n                if attr is not None:\n                    attr.data = attr.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n\ndef build_model(state_dict: dict):\n    vit = \"visual.proj\" in state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len([k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n        image_resolution = vision_patch_size * grid_size\n    else:\n        counts: list = [len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n        vision_patch_size = None\n        assert output_width ** 2 + 1 == 
state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n        image_resolution = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n\n    model = CLIP(\n        embed_dim,\n        image_resolution, vision_layers, vision_width, vision_patch_size,\n        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers\n    )\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        if key in state_dict:\n            del state_dict[key]\n\n    convert_weights(model)\n    model.load_state_dict(state_dict)\n    return model.eval()"
  },
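`build_model` does not read the input resolution from a config. It recovers the resolution from tensor shapes in the checkpoint. The pure-Python sketch below restates that arithmetic. The two helpers, `infer_vit_resolution` and `infer_resnet_resolution`, are hypothetical: they take the row count of the relevant positional embedding instead of a real `state_dict`.

```python
def infer_vit_resolution(num_positions: int, patch_size: int) -> int:
    """ViT branch of build_model: one positional-embedding row per patch on a
    square grid, plus one row for the class token."""
    grid_size = round((num_positions - 1) ** 0.5)
    return patch_size * grid_size


def infer_resnet_resolution(num_positions: int) -> int:
    """ResNet branch: the attention-pool embedding covers an
    output_width x output_width grid plus one pooled token, and the
    backbone downsamples by 32."""
    output_width = round((num_positions - 1) ** 0.5)
    assert output_width ** 2 + 1 == num_positions, "grid is not square"
    return output_width * 32


# A 224-pixel model with 16-pixel patches has a 14 x 14 grid: 196 + 1 rows.
print(infer_vit_resolution(197, 16))  # 224
# A ResNet attention pool on a 7 x 7 map has 49 + 1 rows.
print(infer_resnet_resolution(50))    # 224
```

The same shape-driven approach sets `transformer_heads = transformer_width // 64`, which fixes the head width at 64.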
  {
    "path": "clip/simple_tokenizer.py",
    "content": "### CLIP source code from OpenAI:\n# https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py\n\nimport gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\")\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns a dict mapping every utf-8 byte (0-255) to a unicode character.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a significant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    The mapping also avoids whitespace/control characters, which the bpe code cannot handle.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\nclass SimpleTokenizer(object):\n    def __init__(self, bpe_path: str = default_bpe()):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        with gzip.open(bpe_path) as f:\n            merges = f.read().decode(\"utf-8\").split('\\n')\n        # skip the header line; 48894 merges + 512 byte tokens + 2 special tokens = 49408-token vocab\n        merges = merges[1:49152-256-2+1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v+'</w>' for v in vocab]\n        for merge in merges:\n            vocab.append(''.join(merge))\n        vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n        self.pat = re.compile(r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + (token[-1] + '</w>',)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token+'</w>'\n\n        while True:\n            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except ValueError:  # first does not occur again in the word\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word)-1 and word[i+1] == second:\n                    new_word.append(first+second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = ' '.join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = whitespace_clean(basic_clean(text)).lower()\n        for token in re.findall(self.pat, text):\n            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = ''.join([self.decoder[token] for token in tokens])\n        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n        return text\n"
  },
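`SimpleTokenizer.bpe` runs a greedy loop. It finds the adjacent pair with the lowest merge rank, fuses every occurrence of that pair, and repeats until no known pair remains. The sketch below reproduces that loop on a hypothetical three-entry rank table (`TOY_RANKS`) instead of the real `bpe_simple_vocab_16e6.txt.gz` merges. It scans the word one symbol at a time rather than jumping ahead with `word.index`. Both scans produce the same merged word.

```python
def get_pairs(word):
    """Set of adjacent symbol pairs, as in simple_tokenizer.get_pairs."""
    return set(zip(word, word[1:]))


def bpe(token, ranks):
    """Greedily merge the lowest-ranked adjacent pair until none is in ranks."""
    word = tuple(token[:-1]) + (token[-1] + "</w>",)  # mark the end of the word
    pairs = get_pairs(word)
    if not pairs:
        return token + "</w>"
    while True:
        bigram = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if bigram not in ranks:
            break  # no learned merge applies any more
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)  # fuse this occurrence
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    return " ".join(word)


# Hypothetical merge table: a lower rank means the merge was learned earlier.
TOY_RANKS = {("l", "o"): 0, ("e", "r</w>"): 1, ("lo", "w</w>"): 2}
print(bpe("low", TOY_RANKS))    # low</w>
print(bpe("lower", TOY_RANKS))  # lo w er</w>
```

In the real tokenizer, each space-separated piece of the result is then looked up in `self.encoder` to produce token IDs.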
  {
    "path": "clip_segmentor.py",
    "content": "import torch\nimport torch.nn as nn\nimport sys\nsys.path.append(\"..\")\n\nimport clip\nfrom prompts.imagenet_template import openai_imagenet_template\n\nfrom mmseg.models.segmentors import BaseSegmentor\nfrom mmseg.models.data_preprocessor import SegDataPreProcessor\nfrom mmengine.structures import PixelData\n\nfrom mmseg.registry import MODELS\n\nfrom pamr import PAMR\n\n@MODELS.register_module()\nclass CLIPForSegmentation(BaseSegmentor):\n    def __init__(self, clip_path, name_path, device=torch.device('cuda'),\n                    pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40,\n                    slide_stride=112, slide_crop=224, area_thd=None):\n\n        # CLIP's RGB mean/std scaled to [0, 255]; the channel flip turns the\n        # BGR images loaded by MMSeg into the RGB order CLIP expects\n        data_preprocessor = SegDataPreProcessor(\n            mean=[122.771, 116.746, 104.094],\n            std=[68.501, 66.632, 70.323],\n            rgb_to_bgr=True)\n        super().__init__(data_preprocessor=data_preprocessor)\n        self.net, _ = clip.load(clip_path, device=device, jit=False)\n\n        # one text query per class name; query_idx maps each query to its class\n        query_words, self.query_idx = get_cls_idx(name_path)\n        self.num_queries = len(query_words)\n        self.num_classes = max(self.query_idx) + 1\n        self.query_idx = torch.tensor(self.query_idx, dtype=torch.int64, device=device)\n\n        # embed each name with every prompt template and average the normalized features\n        query_features = []\n        with torch.no_grad():\n            for qw in query_words:\n                query = clip.tokenize([temp(qw) for temp in openai_imagenet_template]).to(device)\n                feature = self.net.encode_text(query)\n                feature /= feature.norm(dim=-1, keepdim=True)\n                feature = feature.mean(dim=0)\n                feature /= feature.norm()\n                query_features.append(feature.unsqueeze(0))\n        self.query_features = torch.cat(query_features, dim=0)\n\n        self.dtype = self.query_features.dtype\n        self.logit_scale = logit_scale\n        self.prob_thd = prob_thd\n        self.area_thd = area_thd\n        self.slide_stride = slide_stride\n        self.slide_crop = slide_crop\n        self.align_corners = False\n\n        if pamr_steps > 0:\n            self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device)\n        else:\n            self.pamr = None\n\n    def forward_feature(self, img, logit_size=None):\n        if isinstance(img, list):\n            img = img[0]\n\n        image_features = self.net.encode_image(img, return_all=True, csa=True)\n        image_features /= image_features.norm(dim=-1, keepdim=True)\n        image_features = image_features[:, 1:]  # drop the class token, keep the patch tokens\n        logits = image_features @ self.query_features.T  # cosine similarity per patch and query\n\n        patch_size = self.net.visual.patch_size\n        h, w = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size\n        out_dim = logits.shape[-1]\n        logits = logits.permute(0, 2, 1).reshape(-1, out_dim, h, w)\n\n        if logit_size is None:\n            logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear')\n        else:\n            logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear')\n\n        return logits\n\n    def forward_slide(self, img, img_metas, stride=112, crop_size=224):\n        \"\"\"Inference by sliding-window with overlap.\n        If h_crop > h_img or w_crop > w_img, the small patch will be used to\n        decode without padding.\n        \"\"\"\n        if isinstance(img, list):\n            img = img[0].unsqueeze(0)\n        if isinstance(stride, int):\n            stride = (stride, stride)\n        if isinstance(crop_size, int):\n            crop_size = (crop_size, crop_size)\n\n        h_stride, w_stride = stride\n        h_crop, w_crop = crop_size\n        batch_size, _, h_img, w_img = img.shape\n        out_channels = self.num_queries\n        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1\n        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1\n        preds = img.new_zeros((batch_size, out_channels, h_img, w_img))\n        count_mat = 
img.new_zeros((batch_size, 1, h_img, w_img))\n        for h_idx in range(h_grids):\n            for w_idx in range(w_grids):\n                y1 = h_idx * h_stride\n                x1 = w_idx * w_stride\n                y2 = min(y1 + h_crop, h_img)\n                x2 = min(x1 + w_crop, w_img)\n                y1 = max(y2 - h_crop, 0)\n                x1 = max(x2 - w_crop, 0)\n                crop_img = img[:, :, y1:y2, x1:x2]\n                crop_seg_logit = self.forward_feature(crop_img)\n                preds += nn.functional.pad(crop_seg_logit,\n                               (int(x1), int(preds.shape[3] - x2), int(y1),\n                                int(preds.shape[2] - y2)))\n\n                count_mat[:, :, y1:y2, x1:x2] += 1\n        assert (count_mat == 0).sum() == 0\n\n        preds = preds / count_mat  # average the overlapping windows\n        img_size = img_metas[0]['ori_shape'][:2]\n        logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear')\n\n        if self.pamr is not None:\n            img = nn.functional.interpolate(img, size=img_size, mode='bilinear')\n            logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype)\n\n        return logits\n\n    def predict(self, inputs, data_samples):\n        if data_samples is not None:\n            batch_img_metas = [\n                data_sample.metainfo for data_sample in data_samples\n            ]\n        else:\n            batch_img_metas = [\n                dict(\n                    ori_shape=inputs.shape[2:],\n                    img_shape=inputs.shape[2:],\n                    pad_shape=inputs.shape[2:],\n                    padding_size=[0, 0, 0, 0])\n            ] * inputs.shape[0]\n\n        if self.slide_crop > 0:\n            seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop)\n        else:\n            seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape'])\n\n        return self.postprocess_result(seg_logits, data_samples)\n\n    def postprocess_result(self, seg_logits, data_samples):\n        batch_size = seg_logits.shape[0]\n        num_cls, num_queries = self.num_classes, self.num_queries\n        for i in range(batch_size):\n            # use a per-image variable so the batch tensor is not overwritten\n            seg_logit = seg_logits[i] * self.logit_scale\n            seg_logit = seg_logit.softmax(0)  # n_queries * h * w\n\n            if num_cls != num_queries:\n                # several names share one class: keep the highest probability per class\n                seg_logit = seg_logit.unsqueeze(0)\n                cls_index = nn.functional.one_hot(self.query_idx, num_cls)\n                cls_index = cls_index.T.view(num_cls, num_queries, 1, 1)\n                seg_logit = (seg_logit * cls_index).max(1)[0]\n\n            if self.area_thd is not None:\n                # Force segmentations with area < self.area_thd to 0 (background)\n                predictions = nn.functional.one_hot(seg_logit.argmax(0), num_cls).to(seg_logit.dtype)\n                area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True)  # exclude background (class 0)\n                area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logit.dtype)\n                seg_logit[1:] *= area_pred.transpose(0, -1)\n\n            seg_pred = seg_logit.argmax(0, keepdim=True)\n            # pixels whose top probability is below prob_thd fall back to class 0\n            seg_pred[seg_logit.max(0, keepdim=True)[0] < self.prob_thd] = 0\n\n            data_samples[i].set_data({\n                'seg_logits':\n                PixelData(**{'data': seg_logit}),\n                'pred_sem_seg':\n                PixelData(**{'data': seg_pred})\n            })\n\n        return data_samples\n\n    def _forward(self, inputs, data_samples=None):\n        \"\"\"\n        \"\"\"\n\n    def inference(self, img, batch_img_metas):\n        \"\"\"\n        \"\"\"\n\n    def encode_decode(self, inputs, batch_img_metas):\n        \"\"\"\n        \"\"\"\n\n    def extract_feat(self, inputs):\n        \"\"\"\n        \"\"\"\n\n    def loss(self, inputs, data_samples):\n        \"\"\"\n        \"\"\"\n\ndef get_cls_idx(path):\n    with open(path, 'r') as f:\n        name_sets = f.readlines()\n    num_cls = len(name_sets)\n\n    class_names, class_indices = [], []\n    for idx in range(num_cls):\n        names_i = name_sets[idx].split(', ')\n        class_names += names_i\n        class_indices += [idx for _ in range(len(names_i))]\n    class_names = [item.replace('\\n', '') for item in class_names]\n    return class_names, class_indices"
  },
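`forward_slide` tiles the resized image with overlapping `crop_size` windows placed every `stride` pixels. Any window that would overrun the border is shifted back to end at it. The per-window logits are then averaged with `count_mat`. The framework-free sketch below reproduces only the window arithmetic. `slide_windows` and `min_coverage` are hypothetical helpers, and their defaults mirror the segmentor's `slide_crop=224` and `slide_stride=112`.

```python
def slide_windows(h_img, w_img, h_crop=224, w_crop=224, h_stride=112, w_stride=112):
    """Return (y1, y2, x1, x2) for every window, in the order forward_slide visits them."""
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    windows = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1, x1 = h_idx * h_stride, w_idx * w_stride
            y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
            # a window that overruns the border is shifted back to end at it
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            windows.append((y1, y2, x1, x2))
    return windows


def min_coverage(h_img, w_img, windows):
    """Smallest per-pixel window count (forward_slide asserts it is > 0)."""
    count = [[0] * w_img for _ in range(h_img)]
    for y1, y2, x1, x2 in windows:
        for y in range(y1, y2):
            row = count[y]
            for x in range(x1, x2):
                row[x] += 1
    return min(min(row) for row in count)


wins = slide_windows(336, 448)
print(len(wins), wins[-1])           # 6 (112, 336, 224, 448)
print(min_coverage(336, 448, wins))  # 1
```

When the image is smaller than the crop, the grid collapses to a single window covering the whole image. This matches the `forward_slide` docstring's "decode without padding" case.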
  {
    "path": "configs/base_config.py",
    "content": "# base configurations\nmodel = dict(\n    type='CLIPForSegmentation',\n    clip_path='ViT-B/16'\n)\n\ntest_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])\n\ndefault_scope = 'mmseg'\nenv_cfg = dict(\n    cudnn_benchmark=True,\n    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n    dist_cfg=dict(backend='nccl'),\n)\nvis_backends = [dict(type='LocalVisBackend')]\nvisualizer = dict(\n    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')\nlog_processor = dict(by_epoch=False)\nlog_level = 'INFO'\nload_from = None\nresume = False\n\ntest_cfg = dict(type='TestLoop')\n\ndefault_hooks = dict(\n    timer=dict(type='IterTimerHook'),\n    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),\n    param_scheduler=dict(type='ParamSchedulerHook'),\n    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),\n    sampler_seed=dict(type='DistSamplerSeedHook'),\n    visualization=dict(type='SegVisualizationHook', interval=1))"
  },
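Every per-dataset config in this repo begins with `_base_ = './base_config.py'` and redefines only a few `model` keys. MMEngine's config loader merges such dicts recursively, so the base `type` and `clip_path` survive the override. The sketch below approximates that merge with a hypothetical `merge_cfg`. It ignores MMEngine features such as `_delete_=True`, which replaces a dict instead of merging it.

```python
import copy


def merge_cfg(base: dict, override: dict) -> dict:
    """Merge override into a deep copy of base: nested dicts are merged key by
    key, every other value in override replaces the base value."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


base = {"model": {"type": "CLIPForSegmentation", "clip_path": "ViT-B/16"}}
voc21 = {"model": {"name_path": "./configs/cls_voc21.txt",
                   "logit_scale": 65, "prob_thd": 0.1, "area_thd": 0.1}}
print(merge_cfg(base, voc21)["model"])
```

Keys left unset in both files, such as `slide_stride` and `slide_crop`, fall back to the defaults in `CLIPForSegmentation.__init__`.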
  {
    "path": "configs/cfg_ade20k.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_ade20k.txt'\n)\n\n# dataset settings\ndataset_type = 'ADE20KDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='images/validation',\n            seg_map_path='annotations/validation'),\n        pipeline=test_pipeline))"
  },
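The ADE20k config sets `reduce_zero_label=True`. Under that setting, annotation value 0 (unlabeled in ADE20k) becomes the ignore index 255, and every other label shifts down by one. As a result, the first line of `cls_ade20k.txt` (`wall`, stored as 1 in the annotations) lines up with prediction index 0. The per-pixel sketch below assumes MMSeg's rule: map 0 to 255, subtract 1, and map the resulting 254 back to 255. `reduce_zero_label` here is a hypothetical scalar helper, not the MMSeg transform.

```python
IGNORE_INDEX = 255  # MMSeg's default ignore label


def reduce_zero_label(label: int) -> int:
    """Scalar version of the reduce_zero_label remapping."""
    if label == 0:
        return IGNORE_INDEX  # "unlabeled" is excluded from mIoU
    shifted = label - 1
    # an existing 255 (already ignored) would become 254; keep it ignored
    return IGNORE_INDEX if shifted == IGNORE_INDEX - 1 else shifted


print([reduce_zero_label(v) for v in (0, 1, 2, 150, 255)])  # [255, 0, 1, 149, 255]
```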
  {
    "path": "configs/cfg_city_scapes.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_city_scapes.txt'\n)\n\n# dataset settings\ndataset_type = 'CityscapesDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 560), keep_ratio=True),\n    # add loading annotation after ``Resize`` because ground truth\n    # does not need to do resize data transform\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='leftImg8bit/val', seg_map_path='gtFine/val'),\n        pipeline=test_pipeline))"
  },
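In these test pipelines, `Resize` with `keep_ratio=True` treats `scale=(2048, 560)` as an upper bound on the long and short sides. It rescales by whichever bound is tighter. The sketch below assumes MMCV's keep-ratio rule, `min(long_max / long_side, short_max / short_side)`, rounded to the nearest pixel. `keep_ratio_size` is a hypothetical helper. For a 2048 x 1024 Cityscapes frame, the short-side bound wins.

```python
def keep_ratio_size(w: int, h: int, scale=(2048, 336)) -> tuple:
    """Return the (w, h) an image is resized to under a keep-ratio scale bound."""
    long_max, short_max = max(scale), min(scale)
    factor = min(long_max / max(w, h), short_max / min(w, h))
    return int(w * factor + 0.5), int(h * factor + 0.5)


print(keep_ratio_size(2048, 1024, scale=(2048, 560)))  # (1120, 560)
print(keep_ratio_size(683, 512))                        # (448, 336)
```

Under this assumption, the 1120 x 560 Cityscapes input is not encoded in one pass. With the default `slide_crop=224` and `slide_stride=112`, it is covered by 4 x 9 = 36 sliding windows.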
  {
    "path": "configs/cfg_coco_object.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_coco_object.txt',\n    logit_scale=50,\n    prob_thd=0.1\n)\n\n# dataset settings\ndataset_type = 'COCOObjectDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    # add loading annotation after ``Resize`` because ground truth\n    # does not need to do resize data transform\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        reduce_zero_label=False,\n        data_prefix=dict(\n            img_path='images/val2017', seg_map_path='annotations/val2017'),\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_coco_stuff10k.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_coco_stuff.txt'\n)\n\n# dataset settings\ndataset_type = 'COCOStuffDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        reduce_zero_label=True,\n        data_prefix=dict(\n            img_path='images/test2014', seg_map_path='annotations/test2014'),\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_coco_stuff164k.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_coco_stuff.txt'\n)\n\n# dataset settings\ndataset_type = 'COCOStuffDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 448), keep_ratio=True),\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='images/val2017', seg_map_path='annotations/val2017'),\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_context59.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_context59.txt'\n)\n\n# dataset settings\ndataset_type = 'PascalContext59Dataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations', reduce_zero_label=True),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='JPEGImages', seg_map_path='SegmentationClassContext'),\n        ann_file='ImageSets/SegmentationContext/val.txt',\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_context60.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_context60.txt',\n    logit_scale=50,\n    prob_thd=0.1\n)\n\n# dataset settings\ndataset_type = 'PascalContext60Dataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='JPEGImages', seg_map_path='SegmentationClassContext'),\n        ann_file='ImageSets/SegmentationContext/val.txt',\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_voc20.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_voc20.txt'\n)\n\n# dataset settings\ndataset_type = 'PascalVOC20Dataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='JPEGImages', seg_map_path='SegmentationClass'),\n        ann_file='ImageSets/Segmentation/val.txt',\n        pipeline=test_pipeline))"
  },
  {
    "path": "configs/cfg_voc21.py",
    "content": "_base_ = './base_config.py'\n\n# model settings\nmodel = dict(\n    name_path='./configs/cls_voc21.txt',\n    logit_scale=65,\n    prob_thd=0.1,\n    area_thd=0.1\n)\n\n# dataset settings\ndataset_type = 'PascalVOCDataset'\ndata_root = ''\n\ntest_pipeline = [\n    dict(type='LoadImageFromFile'),\n    dict(type='Resize', scale=(2048, 336), keep_ratio=True),\n    dict(type='LoadAnnotations'),\n    dict(type='PackSegInputs')\n]\n\ntest_dataloader = dict(\n    batch_size=1,\n    num_workers=4,\n    persistent_workers=True,\n    sampler=dict(type='DefaultSampler', shuffle=False),\n    dataset=dict(\n        type=dataset_type,\n        data_root=data_root,\n        data_prefix=dict(\n            img_path='JPEGImages', seg_map_path='SegmentationClass'),\n        ann_file='ImageSets/Segmentation/val.txt',\n        pipeline=test_pipeline))"
  },
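The VOC21 config raises `logit_scale` to 65 and sets `prob_thd=0.1`. In `postprocess_result`, a pixel's cosine similarities are multiplied by `logit_scale` and softmaxed over the class queries. If the pixel's top probability falls below `prob_thd`, it is assigned to class 0, which is background in VOC21. The single-pixel sketch below assumes one name per class and omits the synonym-max and `area_thd` steps. `classify_pixel` is a hypothetical helper, and the similarity values are made up.

```python
import math


def classify_pixel(cos_sims, logit_scale=65.0, prob_thd=0.1):
    """Return (class index, probabilities) for one pixel's query similarities."""
    scaled = [logit_scale * s for s in cos_sims]
    peak = max(scaled)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < prob_thd:
        return 0, probs  # low confidence: fall back to class 0
    return best, probs


# Small similarity gaps become decisive once multiplied by 65.
print(classify_pixel([0.20, 0.25, 0.21])[0])     # 1
# 21 near-identical scores leave every probability below 0.1.
print(classify_pixel([0.20] * 20 + [0.201])[0])  # 0
```

A larger `logit_scale` sharpens the softmax, so fewer pixels fall under `prob_thd`.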
  {
    "path": "configs/cls_ade20k.txt",
    "content": "wall\nbuilding\nsky\nfloor\ntree\nceiling\nroad\nbed\nwindowpane\ngrass\ncabinet\nsidewalk\nperson\nearth\ndoor\ntable\nmountain\nplant\ncurtain\nchair\ncar\nwater\npainting\nsofa\nshelf\nhouse\nsea\nmirror\nrug\nfield\narmchair\nseat\nfence\ndesk\nrock\nwardrobe\nlamp\nbathtub\nrailing\ncushion\nbase\nbox\ncolumn\nsignboard\nchestofdrawers\ncounter\nsand\nsink\nskyscraper\nfireplace\nrefrigerator\ngrandstand\npath\nstairs\nrunway\ncase\npooltable\npillow\nscreendoor\nstairway\nriver\nbridge\nbookcase\nblind\ncoffeetable\ntoilet\nflower\nbook\nhill\nbench\ncountertop\nstove\npalm\nkitchenisland\ncomputer\nswivelchair\nboat\nbar\narcademachine\nhovel\nbus\ntowel\nlight\ntruck\ntower\nchandelier\nawning\nstreetlight\nbooth\ntelevisionreceiver\nairplane\ndirttrack\napparel\npole\nland\nbannister\nescalator\nottoman\nbottle\nbuffet\nposter\nstage\nvan\nship\nfountain\nconveyerbelt\ncanopy\nwasher\nplaything\nswimmingpool\nstool\nbarrel\nbasket\nwaterfall\ntent\nbag\nminibike\ncradle\noven\nball\nfood\nstep\ntank\ntradename\nmicrowave\npot\nanimal\nbicycle\nlake\ndishwasher\nscreen\nblanket\nsculpture\nhood\nsconce\nvase\ntrafficlight\ntray\nashcan\nfan\npier\ncrtscreen\nplate\nmonitor\nbulletinboard\nshower\nradiator\nglass\nclock\nflag"
  },
  {
    "path": "configs/cls_city_scapes.txt",
    "content": "road\nsidewalk\nbuilding\nwall\nfence\npole\ntrafficlight\ntrafficsign\nvegetation\nterrain\nsky\nperson\nrider\ncar\ntruck\nbus\ntrain\nmotorcycle\nbicycle"
  },
  {
    "path": "configs/cls_coco_object.txt",
    "content": "sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence\nperson, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body\nbicycle\ncar\nmotorcycle\nairplane\nbus\ntrain\ntruck\nboat\ntraffic light\nfire hydrant\nstop sign\nparking meter\nbench\nbird\ncat\ndog\nhorse\nsheep\ncow\nelephant\nbear\nzebra\ngiraffe\nbackpack\numbrella\nhandbag\ntie\nsuitcase\nfrisbee\nskis\nsnowboard\nsports ball\nkite\nbaseball bat\nbaseball glove\nskateboard\nsurfboard\ntennis racket\nbottle\nwine glass\ncup\nfork\nknife\nspoon\nbowl\nbanana\napple\nsandwich\norange\nbroccoli\ncarrot\nhot dog\npizza\ndonut\ncake\nchair\ncouch\npotted plant\nbed\ndining table\ntoilet\ntv\nlaptop\nmouse\nremote\nkeyboard\ncell phone\nmicrowave\noven\ntoaster\nsink\nrefrigerator\nbook\nclock\nvase\nscissors\nteddy bear\nhair drier\ntoothbrush"
  },
  {
    "path": "configs/cls_coco_stuff.txt",
    "content": "person\nbicycle\ncar\nmotorcycle\nairplane\nbus\ntrain\ntruck\nboat\ntrafficlight\nfirehydrant\nstopsign\nparkingmeter\nbench\nbird\ncat\ndog\nhorse\nsheep\ncow\nelephant\nbear\nzebra\ngiraffe\nbackpack\numbrella\nhandbag\ntie\nsuitcase\nfrisbee\nskis\nsnowboard\nsportsball\nkite\nbaseballbat\nbaseballglove\nskateboard\nsurfboard\ntennisracket\nbottle\nwineglass\ncup\nfork\nknife\nspoon\nbowl\nbanana\napple\nsandwich\norange\nbroccoli\ncarrot\nhotdog\npizza\ndonut\ncake\nchair\ncouch\npottedplant\nbed\ndiningtable\ntoilet\ntv\nlaptop\nmouse\nremote\nkeyboard\ncellphone\nmicrowave\noven\ntoaster\nsink\nrefrigerator\nbook\nclock\nvase\nscissors\nteddybear\nhairdrier\ntoothbrush\nbanner\nblanket\nbranch\nbridge\nbuilding-other\nbush\ncabinet\ncage\ncardboard\ncarpet\nceiling-other\nceiling-tile\ncloth\nclothes\nclouds\ncounter\ncupboard\ncurtain\ndesk-stuff\ndirt\ndoor-stuff\nfence\nfloor-marble\nfloor-other\nfloor-stone\nfloor-tile\nfloor-wood\nflower\nfog\nfood-other\nfruit\nfurniture-other\ngrass\ngravel\nground-other\nhill\nhouse\nleaves\nlight\nmat\nmetal\nmirror-stuff\nmoss\nmountain\nmud\nnapkin\nnet\npaper\npavement\npillow\nplant-other\nplastic\nplatform\nplayingfield\nrailing\nrailroad\nriver\nroad\nrock\nroof\nrug\nsalad\nsand\nsea\nshelf\nsky-other\nskyscraper\nsnow\nsolid-other\nstairs\nstone\nstraw\nstructural-other\ntable\ntent\ntextile-other\ntowel\ntree\nvegetable\nwall-brick\nwall-concrete\nwall-other\nwall-panel\nwall-stone\nwall-tile\nwall-wood\nwater-other\nwaterdrops\nwindow-blind\nwindow-other\nwood"
  },
  {
    "path": "configs/cls_context59.txt",
    "content": "aeroplane\nbag\nbed\nbedclothes\nbench\nbicycle\nbird\nboat\nbook\nbottle\nbuilding\nbus\ncabinet\ncar\ncat\nceiling\nchair\ncloth\ncomputer\ncow\ncup\ncurtain\ndog\ndoor\nfence\nfloor\nflower\nfood\ngrass\nground\nhorse\nkeyboard\nlight\nmotorbike\nmountain\nmouse\nperson\nplate\nplatform\npottedplant\nroad\nrock\nsheep\nshelves\nsidewalk\nsign\nsky\nsnow\nsofa\ntable\ntrack\ntrain\ntree\ntruck\ntvmonitor\nwall\nwater\nwindow\nwood"
  },
  {
    "path": "configs/cls_context60.txt",
    "content": "background\naeroplane\nbag\nbed\nbedclothes\nbench\nbicycle\nbird\nboat\nbook\nbottle\nbuilding\nbus\ncabinet\ncar\ncat\nceiling\nchair\ncloth\ncomputer\ncow\ncup\ncurtain\ndog\ndoor\nfence\nfloor\nflower\nfood\ngrass\nground\nhorse\nkeyboard\nlight\nmotorbike\nmountain\nmouse\nperson\nplate\nplatform\npottedplant\nroad\nrock\nsheep\nshelves\nsidewalk\nsign\nsky\nsnow\nsofa\ntable\ntrack\ntrain\ntree\ntruck\ntvmonitor\nwall\nwater\nwindow\nwood"
  },
  {
    "path": "configs/cls_voc20.txt",
    "content": "aeroplane\nbicycle\nbird\nship\nbottle\nbus\ncar\ncat\nchair\ncow\ntable\ndog\nhorse\nmotorbike\nperson, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket\npottedplant\nsheep\nsofa\ntrain\ntelevision monitor, tv monitor, monitor, television, screen"
  },
  {
    "path": "configs/cls_voc21.txt",
    "content": "sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence\naeroplane\nbicycle\nbird\nship\nbottle\nbus\ncar\ncat\nchair\ncow\ntable\ndog\nhorse\nmotorbike\nperson, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket\npottedplant\nsheep\nsofa\ntrain\ntelevision monitor, tv monitor, monitor, television, screen"
  },
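In these `cls_*.txt` files, each line is one class, and a line may list several comma-separated names. `get_cls_idx` flattens them into one query per name and records which line each name came from. `postprocess_result` later keeps, for every class, the highest probability among its names. The sketch below mirrors both steps on a shortened VOC21-style file with made-up probabilities. `parse_class_names` and `max_per_class` are hypothetical helpers.

```python
def parse_class_names(text: str):
    """Split each line on ', ' and remember the line (class) index of each name."""
    names, indices = [], []
    for idx, line in enumerate(text.splitlines()):
        line_names = line.split(", ")
        names += line_names
        indices += [idx] * len(line_names)
    return names, indices


def max_per_class(query_probs, query_idx):
    """Keep the highest query probability for every class index."""
    scores = [0.0] * (max(query_idx) + 1)
    for prob, cls in zip(query_probs, query_idx):
        scores[cls] = max(scores[cls], prob)
    return scores


names, idx = parse_class_names("sky, wall, tree\naeroplane\nperson, person in shirt")
print(names)  # ['sky', 'wall', 'tree', 'aeroplane', 'person', 'person in shirt']
print(idx)    # [0, 0, 0, 1, 2, 2]
print(max_per_class([0.1, 0.05, 0.05, 0.2, 0.35, 0.25], idx))  # [0.1, 0.2, 0.35]
```

The number of classes is the number of lines, while the number of text queries is the number of names. That is why the segmentor distinguishes `num_classes` from `num_queries`.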
  {
    "path": "custom_datasets.py",
    "content": "import os.path as osp\nimport mmengine.fileio as fileio\n\nfrom mmseg.registry import DATASETS\nfrom mmseg.datasets import BaseSegDataset\n\n@DATASETS.register_module()\nclass PascalVOC20Dataset(BaseSegDataset):\n    \"\"\"Pascal VOC dataset without the background class (20 object classes).\n\n    With ``reduce_zero_label=True`` the background label 0 is mapped to the\n    ignore index and the object labels 1-20 are shifted to 0-19.\n\n    Args:\n        ann_file (str): Split txt file listing the image IDs for Pascal VOC.\n    \"\"\"\n    METAINFO = dict(\n        classes=('aeroplane', 'bicycle', 'bird', 'boat',\n                 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',\n                 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',\n                 'sofa', 'train', 'tvmonitor'),\n        palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192],\n                 [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64],\n                 [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],\n                 [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],\n                 [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],\n                 [0, 64, 128]])\n\n    def __init__(self,\n                 ann_file,\n                 img_suffix='.jpg',\n                 seg_map_suffix='.png',\n                 reduce_zero_label=True,\n                 **kwargs) -> None:\n        super().__init__(\n            img_suffix=img_suffix,\n            seg_map_suffix=seg_map_suffix,\n            reduce_zero_label=reduce_zero_label,\n            ann_file=ann_file,\n            **kwargs)\n        assert fileio.exists(self.data_prefix['img_path'],\n                             self.backend_args) and osp.isfile(self.ann_file)\n\n@DATASETS.register_module()\nclass COCOObjectDataset(BaseSegDataset):\n    \"\"\"COCO-Object dataset: 1 background class + the first 80 (object) classes of COCO-Stuff.\n\n    Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT).\n    \"\"\"\n\n    METAINFO = dict(\n    classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 
'airplane', 'bus', 'train', 'truck', 'boat',\n               'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',\n               'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',\n               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',\n               'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',\n               'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',\n               'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',\n               'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',\n               'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),\n\n    palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],\n               [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],\n               [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],\n               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],\n               [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],\n               [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],\n               [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],\n               [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],\n               [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],\n               [0, 
0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],\n               [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],\n               [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]])\n\n    def __init__(self, **kwargs):\n        super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)\n\n@DATASETS.register_module()\nclass PascalContext60Dataset(BaseSegDataset):\n    METAINFO = dict(\n        classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',\n                 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',\n                 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',\n                 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',\n                 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',\n                 'horse', 'keyboard', 'light', 'motorbike', 'mountain',\n                 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',\n                 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',\n                 'sofa', 'table', 'track', 'train', 'tree', 'truck',\n                 'tvmonitor', 'wall', 'water', 'window', 'wood'),\n        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n                 [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n                 [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n                 [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n                 [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n                 [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n                 [255, 184, 6], [10, 255, 
71], [255, 41, 10], [7, 255, 255],\n                 [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n                 [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n                 [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n                 [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n                 [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n                 [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])\n\n    def __init__(self,\n                 ann_file: str,\n                 img_suffix='.jpg',\n                 seg_map_suffix='.png',\n                 **kwargs) -> None:\n        super().__init__(\n            img_suffix=img_suffix,\n            seg_map_suffix=seg_map_suffix,\n            ann_file=ann_file,\n            reduce_zero_label=False,\n            **kwargs)\n\n\n@DATASETS.register_module()\nclass PascalContext59Dataset(BaseSegDataset):\n    METAINFO = dict(\n        classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',\n                 'bird', 'boat', 'book', 'bottle', 'building', 'bus',\n                 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',\n                 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',\n                 'floor', 'flower', 'food', 'grass', 'ground', 'horse',\n                 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',\n                 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',\n                 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',\n                 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',\n                 'wall', 'water', 'window', 'wood'),\n        palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],\n                 [120, 120, 80], [140, 140, 140], [204, 5, 255],\n                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 
6, 82],\n                 [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n                 [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n                 [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n                 [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n                 [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n                 [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n                 [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n                 [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n                 [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n                 [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n                 [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])\n\n    def __init__(self,\n                 ann_file: str,\n                 img_suffix='.jpg',\n                 seg_map_suffix='.png',\n                 reduce_zero_label=True,\n                 **kwargs):\n        super().__init__(\n            img_suffix=img_suffix,\n            seg_map_suffix=seg_map_suffix,\n            ann_file=ann_file,\n            reduce_zero_label=reduce_zero_label,\n            **kwargs)"
  },
  {
    "path": "datasets/cvt_coco_object.py",
    "content": "# ------------------------------------------------------------------------------\n# GroupViT (https://github.com/NVlabs/GroupViT)\n# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.\n# ------------------------------------------------------------------------------\n\nimport argparse\nimport os.path as osp\nimport shutil\nfrom functools import partial\nfrom glob import glob\n\nimport mmcv\nimport numpy as np\nfrom PIL import Image\n\nCOCO_LEN = 123287\n\nclsID_to_trID = {\n    0: 0,\n    1: 1,\n    2: 2,\n    3: 3,\n    4: 4,\n    5: 5,\n    6: 6,\n    7: 7,\n    8: 8,\n    9: 9,\n    10: 10,\n    12: 11,\n    13: 12,\n    14: 13,\n    15: 14,\n    16: 15,\n    17: 16,\n    18: 17,\n    19: 18,\n    20: 19,\n    21: 20,\n    22: 21,\n    23: 22,\n    24: 23,\n    26: 24,\n    27: 25,\n    30: 26,\n    31: 27,\n    32: 28,\n    33: 29,\n    34: 30,\n    35: 31,\n    36: 32,\n    37: 33,\n    38: 34,\n    39: 35,\n    40: 36,\n    41: 37,\n    42: 38,\n    43: 39,\n    45: 40,\n    46: 41,\n    47: 42,\n    48: 43,\n    49: 44,\n    50: 45,\n    51: 46,\n    52: 47,\n    53: 48,\n    54: 49,\n    55: 50,\n    56: 51,\n    57: 52,\n    58: 53,\n    59: 54,\n    60: 55,\n    61: 56,\n    62: 57,\n    63: 58,\n    64: 59,\n    66: 60,\n    69: 61,\n    71: 62,\n    72: 63,\n    73: 64,\n    74: 65,\n    75: 66,\n    76: 67,\n    77: 68,\n    78: 69,\n    79: 70,\n    80: 71,\n    81: 72,\n    83: 73,\n    84: 74,\n    85: 75,\n    86: 76,\n    87: 77,\n    88: 78,\n    89: 79,\n    91: 80,\n    92: 81,\n    93: 82,\n    94: 83,\n    95: 84,\n    96: 85,\n    97: 86,\n    98: 87,\n    99: 88,\n    100: 89,\n    101: 90,\n    102: 91,\n    103: 92,\n    104: 93,\n    105: 94,\n    106: 95,\n    107: 96,\n    108: 97,\n    109: 98,\n    110: 99,\n    111: 100,\n    112: 101,\n    113: 102,\n    114: 103,\n    115: 104,\n    116: 105,\n    117: 106,\n    118: 107,\n    119: 108,\n    120: 109,\n    121: 110,\n    122: 111,\n    123: 
112,\n    124: 113,\n    125: 114,\n    126: 115,\n    127: 116,\n    128: 117,\n    129: 118,\n    130: 119,\n    131: 120,\n    132: 121,\n    133: 122,\n    134: 123,\n    135: 124,\n    136: 125,\n    137: 126,\n    138: 127,\n    139: 128,\n    140: 129,\n    141: 130,\n    142: 131,\n    143: 132,\n    144: 133,\n    145: 134,\n    146: 135,\n    147: 136,\n    148: 137,\n    149: 138,\n    150: 139,\n    151: 140,\n    152: 141,\n    153: 142,\n    154: 143,\n    155: 144,\n    156: 145,\n    157: 146,\n    158: 147,\n    159: 148,\n    160: 149,\n    161: 150,\n    162: 151,\n    163: 152,\n    164: 153,\n    165: 154,\n    166: 155,\n    167: 156,\n    168: 157,\n    169: 158,\n    170: 159,\n    171: 160,\n    172: 161,\n    173: 162,\n    174: 163,\n    175: 164,\n    176: 165,\n    177: 166,\n    178: 167,\n    179: 168,\n    180: 169,\n    181: 170,\n    255: 255\n}\n\n# Shift thing IDs by one so that 0 is reserved for background, then map\n# all stuff classes (IDs > 90) and the unlabeled ID 255 to background (0)\nfor k, v in clsID_to_trID.items():\n    clsID_to_trID[k] = v + 1\n    if k > 90:\n        clsID_to_trID[k] = 0\n\n\ndef convert_to_trainID(maskpath, out_mask_dir, is_train):\n    mask = np.array(Image.open(maskpath))\n    mask_copy = mask.copy()\n    for clsID, trID in clsID_to_trID.items():\n        mask_copy[mask == clsID] = trID\n    seg_filename = osp.join(\n        out_mask_dir, 'train2017',\n        osp.basename(maskpath).split('.')[0] +\n        '_instanceTrainIds.png') if is_train else osp.join(\n            out_mask_dir, 'val2017',\n            osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')\n    Image.fromarray(mask_copy).save(seg_filename, 'PNG')\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\\\n        'Convert COCO Stuff 164k annotations to COCO Objects')  # noqa\n    parser.add_argument('coco_path', help='coco stuff path')\n    parser.add_argument('-o', '--out_dir', help='output path')\n    parser.add_argument(\n        '--nproc', default=16, type=int, help='number of processes')\n    args = 
parser.parse_args()\n    return args\n\n\ndef main():\n    # mmcv>=2.0 moved these file/progress helpers to mmengine.utils\n    from mmengine.utils import (mkdir_or_exist, track_parallel_progress,\n                                track_progress)\n\n    args = parse_args()\n    coco_path = args.coco_path\n    nproc = args.nproc\n\n    out_dir = args.out_dir or coco_path\n    out_img_dir = osp.join(out_dir, 'images')\n    out_mask_dir = osp.join(out_dir, 'annotations')\n\n    mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))\n    mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))\n\n    if out_dir != coco_path:\n        shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)\n\n    train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))\n    train_list = [file for file in train_list if 'TrainIds' not in file]\n    test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))\n    test_list = [file for file in test_list if 'TrainIds' not in file]\n    assert (len(train_list) +\n            len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(\n                len(train_list), len(test_list))\n\n    if args.nproc > 1:\n        track_parallel_progress(\n            partial(\n                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),\n            train_list,\n            nproc=nproc)\n        track_parallel_progress(\n            partial(\n                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),\n            test_list,\n            nproc=nproc)\n    else:\n        track_progress(\n            partial(\n                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),\n            train_list)\n        track_progress(\n            partial(\n                convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),\n            test_list)\n\n    print('Done!')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dist_test.sh",
    "content": "# Usage: GPUS=4 bash dist_test.sh CONFIG [extra eval.py arguments]\nCONFIG=$1\n\nWORK_DIR=${WORK_DIR:-\"./work_logs\"}\nGPUS=${GPUS:-4}\nNNODES=${NNODES:-1}\nNODE_RANK=${NODE_RANK:-0}\nPORT=${PORT:-29500}\nMASTER_ADDR=${MASTER_ADDR:-\"127.0.0.1\"}\n\nPYTHONPATH=\"$(dirname $0)/..\":$PYTHONPATH \\\npython -m torch.distributed.launch \\\n    --nnodes=$NNODES \\\n    --node_rank=$NODE_RANK \\\n    --master_addr=$MASTER_ADDR \\\n    --nproc_per_node=$GPUS \\\n    --master_port=$PORT \\\n    $(dirname \"$0\")/eval.py \\\n    --config \"$CONFIG\" \\\n    --work-dir \"$WORK_DIR\" \\\n    --launcher pytorch \\\n    \"${@:2}\""
  },
  {
    "path": "eval.py",
    "content": "import os\nimport argparse\nimport clip_segmentor\nimport custom_datasets\n\nfrom mmengine.config import Config\nfrom mmengine.runner import Runner\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='SCLIP evaluation with MMSeg')\n    parser.add_argument('--config', default='')\n    parser.add_argument('--work-dir', default='./work_logs/')\n    parser.add_argument(\n        '--show', action='store_true', help='show prediction results')\n    parser.add_argument(\n        '--show_dir',\n        default='',\n        help='directory to save visualization images')\n    parser.add_argument(\n        '--wait-time', type=float, default=2,\n        help='display interval (in seconds) when --show is set')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch', 'slurm', 'mpi'],\n        default='none',\n        help='job launcher')\n    # When using PyTorch version >= 2.0.0, `torch.distributed.launch`\n    # passes the `--local-rank` parameter to `eval.py` instead\n    # of `--local_rank`.\n    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)\n    args = parser.parse_args()\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    return args\n\ndef trigger_visualization_hook(cfg, args):\n    default_hooks = cfg.default_hooks\n    if 'visualization' in default_hooks:\n        visualization_hook = default_hooks['visualization']\n        # Turn on visualization\n        visualization_hook['draw'] = True\n        if args.show:\n            visualization_hook['show'] = True\n            visualization_hook['wait_time'] = args.wait_time\n        if args.show_dir:\n            visualizer = cfg.visualizer\n            visualizer['save_dir'] = args.show_dir\n    else:\n        raise RuntimeError(\n            'VisualizationHook must be included in default_hooks. '\n            'Refer to usage '\n            '\"visualization=dict(type=\\'VisualizationHook\\')\"')\n\n    return cfg\n\ndef main():\n    args = parse_args()\n\n    cfg = Config.fromfile(args.config)\n    cfg.launcher = args.launcher\n    cfg.work_dir = args.work_dir\n    # Enable the visualization hook only when --show or --show_dir is given\n    if args.show or args.show_dir:\n        cfg = trigger_visualization_hook(cfg, args)\n\n    runner = Runner.from_cfg(cfg)\n    runner.test()\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "pamr.py",
    "content": "# Copyright 2020 TU Darmstadt\n# License: Apache 2.0 License.\n# https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nfrom functools import partial\n\n#\n# Helper modules\n#\nclass LocalAffinity(nn.Module):\n\n    def __init__(self, dilations=[1]):\n        super(LocalAffinity, self).__init__()\n        self.dilations = dilations\n        weight = self._init_aff()\n        self.register_buffer('kernel', weight)\n\n    def _init_aff(self):\n        # initialising the shift kernel\n        weight = torch.zeros(8, 1, 3, 3)\n\n        for i in range(weight.size(0)):\n            weight[i, 0, 1, 1] = 1\n\n        weight[0, 0, 0, 0] = -1\n        weight[1, 0, 0, 1] = -1\n        weight[2, 0, 0, 2] = -1\n\n        weight[3, 0, 1, 0] = -1\n        weight[4, 0, 1, 2] = -1\n\n        weight[5, 0, 2, 0] = -1\n        weight[6, 0, 2, 1] = -1\n        weight[7, 0, 2, 2] = -1\n\n        self.weight_check = weight.clone()\n\n        return weight\n\n    def forward(self, x):\n\n        self.weight_check = self.weight_check.type_as(x)\n        assert torch.all(self.weight_check.eq(self.kernel))\n\n        B,K,H,W = x.size()\n        x = x.view(B*K,1,H,W)\n\n        x_affs = []\n        for d in self.dilations:\n            x_pad = F.pad(x, [d]*4, mode='replicate')\n            x_aff = F.conv2d(x_pad, self.kernel, dilation=d)\n            x_affs.append(x_aff)\n\n        x_aff = torch.cat(x_affs, 1)\n        return x_aff.view(B,K,-1,H,W)\n\nclass LocalAffinityCopy(LocalAffinity):\n\n    def _init_aff(self):\n        # initialising the shift kernel\n        weight = torch.zeros(8, 1, 3, 3)\n\n        weight[0, 0, 0, 0] = 1\n        weight[1, 0, 0, 1] = 1\n        weight[2, 0, 0, 2] = 1\n\n        weight[3, 0, 1, 0] = 1\n        weight[4, 0, 1, 2] = 1\n\n        weight[5, 0, 2, 0] = 1\n        weight[6, 0, 2, 1] = 1\n        weight[7, 0, 2, 2] = 1\n\n        self.weight_check = 
weight.clone()\n        return weight\n\nclass LocalStDev(LocalAffinity):\n\n    def _init_aff(self):\n        weight = torch.zeros(9, 1, 3, 3)\n        weight.zero_()\n\n        weight[0, 0, 0, 0] = 1\n        weight[1, 0, 0, 1] = 1\n        weight[2, 0, 0, 2] = 1\n\n        weight[3, 0, 1, 0] = 1\n        weight[4, 0, 1, 1] = 1\n        weight[5, 0, 1, 2] = 1\n\n        weight[6, 0, 2, 0] = 1\n        weight[7, 0, 2, 1] = 1\n        weight[8, 0, 2, 2] = 1\n\n        self.weight_check = weight.clone()\n        return weight\n\n    def forward(self, x):\n        # returns (B,K,P,H,W), where P is the number\n        # of locations\n        x = super(LocalStDev, self).forward(x)\n\n        return x.std(2, keepdim=True)\n\nclass LocalAffinityAbs(LocalAffinity):\n\n    def forward(self, x):\n        x = super(LocalAffinityAbs, self).forward(x)\n        return torch.abs(x)\n\n#\n# PAMR module\n#\nclass PAMR(nn.Module):\n\n    def __init__(self, num_iter=1, dilations=[1]):\n        super(PAMR, self).__init__()\n\n        self.num_iter = num_iter\n        self.aff_x = LocalAffinityAbs(dilations)\n        self.aff_m = LocalAffinityCopy(dilations)\n        self.aff_std = LocalStDev(dilations)\n\n    def forward(self, x, mask):\n        mask = F.interpolate(mask, size=x.size()[-2:], mode=\"bilinear\", align_corners=True)\n\n        # x: [BxKxHxW]\n        # mask: [BxCxHxW]\n        B,K,H,W = x.size()\n        _,C,_,_ = mask.size()\n\n        x_std = self.aff_std(x)\n\n        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)\n        x = x.mean(1, keepdim=True)\n        x = F.softmax(x, 2)\n\n        for _ in range(self.num_iter):\n            m = self.aff_m(mask)  # [BxCxPxHxW]\n            mask = (m * x).sum(2)\n\n        # xvals: [BxCxHxW]\n        return mask"
  },
  {
    "path": "prompts/imagenet_template.py",
    "content": "\nimagenet_classnames = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\",\n                        \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\",\n                        \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\",\n                        \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\",\n                        \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\",\n                        \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\",\n                        \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\",\n                        \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\",\n                        \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\",\n                        \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\",\n                        \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\",\n                        \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\",\n                        \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\",\n                        \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\",\n                        \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\",\n                        \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\",\n       
                 \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\",\n                        \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\",\n                        \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\",\n                        \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\",\n                        \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\",\n                        \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\",\n                        \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\",\n                        \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\",\n                        \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\",\n                        \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\",\n                        \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\",\n                        \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\",\n                        \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\",\n                        \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\",\n                        \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\",\n                        \"Weimaraner\", \"Staffordshire Bull Terrier\", 
\"American Staffordshire Terrier\",\n                        \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\",\n                        \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\",\n                        \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\",\n                        \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\",\n                        \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\",\n                        \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\",\n                        \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\",\n                        \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\",\n                        \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\",\n                        \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\",\n                        \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\",\n                        \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\",\n                        \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\",\n                        \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\",\n                        \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\",\n                        \"French Bulldog\", \"Great Dane\", \"St. 
Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\",\n                        \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\",\n                        \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\",\n                        \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\",\n                        \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\",\n                        \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\",\n                        \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\",\n                        \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\",\n                        \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\",\n                        \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\",\n                        \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\",\n                        \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\",\n                        \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\",\n                        \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\",\n                        \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\",\n                        \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\",\n                        \"zebra\", 
\"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\",\n                        \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\",\n                        \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\",\n                        \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\",\n                        \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\",\n                        \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\",\n                        \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\",\n                        \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\",\n                        \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\",\n                        \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\",\n                        \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\",\n                        \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\",\n                        \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\",\n                        \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\",\n                        \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\",\n                        \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\",\n   
                     \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\",\n                        \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\",\n                        \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\",\n                        \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\",\n                        \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\",\n                        \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\",\n                        \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\",\n                        \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\",\n                        \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\",\n                        \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\",\n                        \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\",\n                        \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\",\n                        \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\",\n                        \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\",\n                        \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\",\n                        \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", 
\"doormat\", \"drilling rig\",\n                        \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\",\n                        \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\",\n                        \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\",\n                        \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\",\n                        \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\",\n                        \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\",\n                        \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\",\n                        \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\",\n                        \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\",\n                        \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\",\n                        \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\",\n                        \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\",\n                        \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\",\n                        \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\",\n                        \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\",\n                        
\"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\",\n                        \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\",\n                        \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\",\n                        \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\",\n                        \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\",\n                        \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\",\n                        \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\",\n                        \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\",\n                        \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\",\n                        \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\",\n                        \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\",\n                        \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\",\n                        \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\",\n                        \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\",\n                        \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\",\n                        \"pool table\", \"soda bottle\", \"plant pot\", \"potter's 
wheel\", \"power drill\", \"prayer rug\",\n                        \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\",\n                        \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\",\n                        \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\",\n                        \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\",\n                        \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\",\n                        \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\",\n                        \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\",\n                        \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\",\n                        \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\",\n                        \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\",\n                        \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\",\n                        \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\",\n                        \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\",\n                        \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\",\n                        \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\",\n               
         \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\",\n                        \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\",\n                        \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\",\n                        \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\",\n                        \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\",\n                        \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\",\n                        \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\",\n                        \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\",\n                        \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\",\n                        \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\",\n                        \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\",\n                        \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\",\n                        \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\",\n                        \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\",\n                        \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\",\n                        \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", 
\"butternut squash\", \"cucumber\",\n                        \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\",\n                        \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\",\n                        \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\",\n                        \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\",\n                        \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\",\n                        \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\",\n                        \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\",\n                        \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n\n\nopenai_imagenet_template = [\n    lambda c: f'a bad photo of a {c}.',\n    lambda c: f'a photo of many {c}.',\n    lambda c: f'a sculpture of a {c}.',\n    lambda c: f'a photo of the hard to see {c}.',\n    lambda c: f'a low resolution photo of the {c}.',\n    lambda c: f'a rendering of a {c}.',\n    lambda c: f'graffiti of a {c}.',\n    lambda c: f'a bad photo of the {c}.',\n    lambda c: f'a cropped photo of the {c}.',\n    lambda c: f'a tattoo of a {c}.',\n    lambda c: f'the embroidered {c}.',\n    lambda c: f'a photo of a hard to see {c}.',\n    lambda c: f'a bright photo of a {c}.',\n    lambda c: f'a photo of a clean {c}.',\n    lambda c: f'a photo of a dirty {c}.',\n    lambda c: f'a dark photo of the {c}.',\n    lambda c: f'a drawing of a {c}.',\n    lambda c: f'a photo of my {c}.',\n    lambda c: f'the plastic {c}.',\n    lambda c: f'a photo of the cool {c}.',\n    lambda c: 
f'a close-up photo of a {c}.',\n    lambda c: f'a black and white photo of the {c}.',\n    lambda c: f'a painting of the {c}.',\n    lambda c: f'a painting of a {c}.',\n    lambda c: f'a pixelated photo of the {c}.',\n    lambda c: f'a sculpture of the {c}.',\n    lambda c: f'a bright photo of the {c}.',\n    lambda c: f'a cropped photo of a {c}.',\n    lambda c: f'a plastic {c}.',\n    lambda c: f'a photo of the dirty {c}.',\n    lambda c: f'a jpeg corrupted photo of a {c}.',\n    lambda c: f'a blurry photo of the {c}.',\n    lambda c: f'a photo of the {c}.',\n    lambda c: f'a good photo of the {c}.',\n    lambda c: f'a rendering of the {c}.',\n    lambda c: f'a {c} in a video game.',\n    lambda c: f'a photo of one {c}.',\n    lambda c: f'a doodle of a {c}.',\n    lambda c: f'a close-up photo of the {c}.',\n    lambda c: f'a photo of a {c}.',\n    lambda c: f'the origami {c}.',\n    lambda c: f'the {c} in a video game.',\n    lambda c: f'a sketch of a {c}.',\n    lambda c: f'a doodle of the {c}.',\n    lambda c: f'a origami {c}.',\n    lambda c: f'a low resolution photo of a {c}.',\n    lambda c: f'the toy {c}.',\n    lambda c: f'a rendition of the {c}.',\n    lambda c: f'a photo of the clean {c}.',\n    lambda c: f'a photo of a large {c}.',\n    lambda c: f'a rendition of a {c}.',\n    lambda c: f'a photo of a nice {c}.',\n    lambda c: f'a photo of a weird {c}.',\n    lambda c: f'a blurry photo of a {c}.',\n    lambda c: f'a cartoon {c}.',\n    lambda c: f'art of a {c}.',\n    lambda c: f'a sketch of the {c}.',\n    lambda c: f'a embroidered {c}.',\n    lambda c: f'a pixelated photo of a {c}.',\n    lambda c: f'itap of the {c}.',\n    lambda c: f'a jpeg corrupted photo of the {c}.',\n    lambda c: f'a good photo of a {c}.',\n    lambda c: f'a plushie {c}.',\n    lambda c: f'a photo of the nice {c}.',\n    lambda c: f'a photo of the small {c}.',\n    lambda c: f'a photo of the weird {c}.',\n    lambda c: f'the cartoon {c}.',\n    lambda c: f'art of the {c}.',\n 
   lambda c: f'a drawing of the {c}.',\n    lambda c: f'a photo of the large {c}.',\n    lambda c: f'a black and white photo of a {c}.',\n    lambda c: f'the plushie {c}.',\n    lambda c: f'a dark photo of a {c}.',\n    lambda c: f'itap of a {c}.',\n    lambda c: f'graffiti of the {c}.',\n    lambda c: f'a toy {c}.',\n    lambda c: f'itap of my {c}.',\n    lambda c: f'a photo of a cool {c}.',\n    lambda c: f'a photo of a small {c}.',\n    lambda c: f'a tattoo of the {c}.',\n]"
  }
]