[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\r\n                           Version 2.0, January 2004\r\n                        http://www.apache.org/licenses/\r\n\r\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\r\n\r\n   1. Definitions.\r\n\r\n      \"License\" shall mean the terms and conditions for use, reproduction,\r\n      and distribution as defined by Sections 1 through 9 of this document.\r\n\r\n      \"Licensor\" shall mean the copyright owner or entity authorized by\r\n      the copyright owner that is granting the License.\r\n\r\n      \"Legal Entity\" shall mean the union of the acting entity and all\r\n      other entities that control, are controlled by, or are under common\r\n      control with that entity. For the purposes of this definition,\r\n      \"control\" means (i) the power, direct or indirect, to cause the\r\n      direction or management of such entity, whether by contract or\r\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\r\n      outstanding shares, or (iii) beneficial ownership of such entity.\r\n\r\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\r\n      exercising permissions granted by this License.\r\n\r\n      \"Source\" form shall mean the preferred form for making modifications,\r\n      including but not limited to software source code, documentation\r\n      source, and configuration files.\r\n\r\n      \"Object\" form shall mean any form resulting from mechanical\r\n      transformation or translation of a Source form, including but\r\n      not limited to compiled object code, generated documentation,\r\n      and conversions to other media types.\r\n\r\n      \"Work\" shall mean the work of authorship, whether in Source or\r\n      Object form, made available under the License, as indicated by a\r\n      copyright notice that is included in or attached to the work\r\n      (an example is provided in the Appendix below).\r\n\r\n      \"Derivative Works\" shall mean any work, whether in Source or Object\r\n      form, that is based on (or derived from) the Work and for which the\r\n      editorial revisions, annotations, elaborations, or other modifications\r\n      represent, as a whole, an original work of authorship. For the purposes\r\n      of this License, Derivative Works shall not include works that remain\r\n      separable from, or merely link (or bind by name) to the interfaces of,\r\n      the Work and Derivative Works thereof.\r\n\r\n      \"Contribution\" shall mean any work of authorship, including\r\n      the original version of the Work and any modifications or additions\r\n      to that Work or Derivative Works thereof, that is intentionally\r\n      submitted to Licensor for inclusion in the Work by the copyright owner\r\n      or by an individual or Legal Entity authorized to submit on behalf of\r\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\r\n      means any form of electronic, verbal, or written communication sent\r\n      to the Licensor or its representatives, including but not limited to\r\n      communication on electronic mailing lists, source code control systems,\r\n      and issue tracking systems that are managed by, or on behalf of, the\r\n      Licensor for the purpose of discussing and improving the Work, but\r\n      excluding communication that is conspicuously marked or otherwise\r\n      designated in writing by the copyright owner as \"Not a Contribution.\"\r\n\r\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\r\n      on behalf of whom a Contribution has been received by Licensor and\r\n      subsequently incorporated within the Work.\r\n\r\n   2. Grant of Copyright License. Subject to the terms and conditions of\r\n      this License, each Contributor hereby grants to You a perpetual,\r\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\r\n      copyright license to reproduce, prepare Derivative Works of,\r\n      publicly display, publicly perform, sublicense, and distribute the\r\n      Work and such Derivative Works in Source or Object form.\r\n\r\n   3. Grant of Patent License. Subject to the terms and conditions of\r\n      this License, each Contributor hereby grants to You a perpetual,\r\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\r\n      (except as stated in this section) patent license to make, have made,\r\n      use, offer to sell, sell, import, and otherwise transfer the Work,\r\n      where such license applies only to those patent claims licensable\r\n      by such Contributor that are necessarily infringed by their\r\n      Contribution(s) alone or by combination of their Contribution(s)\r\n      with the Work to which such Contribution(s) was submitted. If You\r\n      institute patent litigation against any entity (including a\r\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\r\n      or a Contribution incorporated within the Work constitutes direct\r\n      or contributory patent infringement, then any patent licenses\r\n      granted to You under this License for that Work shall terminate\r\n      as of the date such litigation is filed.\r\n\r\n   4. Redistribution. 
You may reproduce and distribute copies of the\r\n      Work or Derivative Works thereof in any medium, with or without\r\n      modifications, and in Source or Object form, provided that You\r\n      meet the following conditions:\r\n\r\n      (a) You must give any other recipients of the Work or\r\n          Derivative Works a copy of this License; and\r\n\r\n      (b) You must cause any modified files to carry prominent notices\r\n          stating that You changed the files; and\r\n\r\n      (c) You must retain, in the Source form of any Derivative Works\r\n          that You distribute, all copyright, patent, trademark, and\r\n          attribution notices from the Source form of the Work,\r\n          excluding those notices that do not pertain to any part of\r\n          the Derivative Works; and\r\n\r\n      (d) If the Work includes a \"NOTICE\" text file as part of its\r\n          distribution, then any Derivative Works that You distribute must\r\n          include a readable copy of the attribution notices contained\r\n          within such NOTICE file, excluding those notices that do not\r\n          pertain to any part of the Derivative Works, in at least one\r\n          of the following places: within a NOTICE text file distributed\r\n          as part of the Derivative Works; within the Source form or\r\n          documentation, if provided along with the Derivative Works; or,\r\n          within a display generated by the Derivative Works, if and\r\n          wherever such third-party notices normally appear. The contents\r\n          of the NOTICE file are for informational purposes only and\r\n          do not modify the License. You may add Your own attribution\r\n          notices within Derivative Works that You distribute, alongside\r\n          or as an addendum to the NOTICE text from the Work, provided\r\n          that such additional attribution notices cannot be construed\r\n          as modifying the License.\r\n\r\n      You may add Your own copyright statement to Your modifications and\r\n      may provide additional or different license terms and conditions\r\n      for use, reproduction, or distribution of Your modifications, or\r\n      for any such Derivative Works as a whole, provided Your use,\r\n      reproduction, and distribution of the Work otherwise complies with\r\n      the conditions stated in this License.\r\n\r\n   5. Submission of Contributions. Unless You explicitly state otherwise,\r\n      any Contribution intentionally submitted for inclusion in the Work\r\n      by You to the Licensor shall be under the terms and conditions of\r\n      this License, without any additional terms or conditions.\r\n      Notwithstanding the above, nothing herein shall supersede or modify\r\n      the terms of any separate license agreement you may have executed\r\n      with Licensor regarding such Contributions.\r\n\r\n   6. Trademarks. This License does not grant permission to use the trade\r\n      names, trademarks, service marks, or product names of the Licensor,\r\n      except as required for reasonable and customary use in describing the\r\n      origin of the Work and reproducing the content of the NOTICE file.\r\n\r\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\r\n      agreed to in writing, Licensor provides the Work (and each\r\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\r\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\r\n      implied, including, without limitation, any warranties or conditions\r\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\r\n      PARTICULAR PURPOSE. You are solely responsible for determining the\r\n      appropriateness of using or redistributing the Work and assume any\r\n      risks associated with Your exercise of permissions under this License.\r\n\r\n   8. Limitation of Liability. In no event and under no legal theory,\r\n      whether in tort (including negligence), contract, or otherwise,\r\n      unless required by applicable law (such as deliberate and grossly\r\n      negligent acts) or agreed to in writing, shall any Contributor be\r\n      liable to You for damages, including any direct, indirect, special,\r\n      incidental, or consequential damages of any character arising as a\r\n      result of this License or out of the use or inability to use the\r\n      Work (including but not limited to damages for loss of goodwill,\r\n      work stoppage, computer failure or malfunction, or any and all\r\n      other commercial damages or losses), even if such Contributor\r\n      has been advised of the possibility of such damages.\r\n\r\n   9. Accepting Warranty or Additional Liability. While redistributing\r\n      the Work or Derivative Works thereof, You may choose to offer,\r\n      and charge a fee for, acceptance of support, warranty, indemnity,\r\n      or other liability obligations and/or rights consistent with this\r\n      License. However, in accepting such obligations, You may act only\r\n      on Your own behalf and on Your sole responsibility, not on behalf\r\n      of any other Contributor, and only if You agree to indemnify,\r\n      defend, and hold each Contributor harmless for any liability\r\n      incurred by, or claims asserted against, such Contributor by reason\r\n      of your accepting any such warranty or additional liability.\r\n\r\n   END OF TERMS AND CONDITIONS\r\n\r\n   APPENDIX: How to apply the Apache License to your work.\r\n\r\n      To apply the Apache License to your work, attach the following\r\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\r\n      replaced with your own identifying information. (Don't include\r\n      the brackets!)  The text should be enclosed in the appropriate\r\n      comment syntax for the file format. We also recommend that a\r\n      file or class name and description of purpose be included on the\r\n      same \"printed page\" as the copyright notice for easier\r\n      identification within third-party archives.\r\n\r\n   Copyright [yyyy] [name of copyright owner]\r\n\r\n   Licensed under the Apache License, Version 2.0 (the \"License\");\r\n   you may not use this file except in compliance with the License.\r\n   You may obtain a copy of the License at\r\n\r\n       http://www.apache.org/licenses/LICENSE-2.0\r\n\r\n   Unless required by applicable law or agreed to in writing, software\r\n   distributed under the License is distributed on an \"AS IS\" BASIS,\r\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n   See the License for the specific language governing permissions and\r\n   limitations under the License.\r\n"
  },
  {
    "path": "README.md",
    "content": "# ComfyUI-PuLID-Flux-Enhanced\nadapted from https://github.com/balazik/ComfyUI-PuLID-Flux\n\nworkflow: see example flux_pulid_multi.json\n\n## oct.7 2025\nFormally discontinued.\nYou guys may just use i2i models like flux kontext/qwen image edit, they are just doing same thing or doing better than Pulid.\n\n## update oct.28 2024\nAdd an optional prior image input for the node. When using the train_weight method, the prior image will act as the main id image, which will lead the other id images to sum up to an optimized id embedding.\n\nThis prior was randomly choosen previously, now we can assign it.\n\nLeaving the prior image input empty is OK just as previous.\n\nPlease choose the best id image in your mind as the prior, or just experiment around and see what happens.\n![oct28](https://github.com/user-attachments/assets/6a481cd9-2836-4f6f-9ad5-7458356c332a)\n\n## new features\n### common fusion methods for multi-image input\nmean(official), concat, max...etc\n\n### some further experimental fusion methods.\nusing the norm of the conditions to weight them\n\nusing the max norm token among images\n\na novel very fast embeddings self-training methods(explained here: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/28)\n\n### switch between using gray image (official) and rgb.\nin some cases, using gray image will bring detail loss\n\n![2024-10-12_204047](https://github.com/user-attachments/assets/0ae96170-2eff-44e9-a53a-6a7447dbc0f1)\n\n## tricks make your generation better\n### fusion method leverages many id images to enhance fidelity\n1. Besides mean fusion, you can try max or max_token, which can boost some major feature of a face (like large eyes, special nose or sth). it can go distortion beyond fidelity though.\n2. With train_weight method, you can train with less than 2000 steps to make a deeper fusion than the non-training methods. Be aware too many training steps will make the training crash to the prior image.\n\n### additional notes\n1. Flux is a high capacity base model, it even can cognize the input image in some super human way. \nfor example, you can resize your high quality input image with lanczos method rather than nearest area or billinear. you get finer texture. Keep in mind that taking care of your input image is the thing when the base model is strong.\n2. The best pulid weight is around 0.8-0.95 for flux pulid 0.9.0. 1.0 is not good. For 0.9.1, it's higher towards around 0.9-1.0. Nonetheless the 0.9.1 is not always better than 0.9.0.\n3. The base model is flux-dev or its finetuning, and the precision does mean the thing. fp16 should always be sound. fp8 is OK. I won't recommend gguf or nf4 things.\n4. Some of the finetuned flux dev model may have strong bias. for example, it may sway the faces to a certain human race.\n5. Euler simple is always working. Euler beta give you higher quality especially if your input image is somewhat low quality.\n6. If you wanna use 3rd party flux-d weight, better to use a merged one or with a lora weight, rather than a finetuned one. Full finetuning can hurt the connection between pulid and original flux-d base model. You can test by yourself though. \n\n## basic notes for common users\nThis is an experimental node. 
It can give enhanced result but I'm not promising basic instructions for users who barely know about python developing or AI developing.\n\nPlease follow the comfyui instructions or https://github.com/balazik/ComfyUI-PuLID-Flux to enable usage.\n\nIf you are just using SDXL pulid, you can use https://github.com/cubiq/PuLID_ComfyUI. Some of the installation instructions there may also help.\n"
  },
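  {
    "path": "example_fusion_methods_sketch.py",
    "content": "\"\"\"Minimal sketch of the multi-image fusion methods described in the README\r\n(mean / concat / max / norm-weighted / max-norm token).\r\n\r\nThis is an illustrative sketch, not the node's actual implementation: the\r\nfunction name, the stacking of per-image id conditions as (num_images, tokens,\r\ndim), and the example shapes are assumptions.\r\n\"\"\"\r\nimport torch\r\n\r\n\r\ndef fuse_id_conditions(conds: torch.Tensor, method: str = 'mean') -> torch.Tensor:\r\n    \"\"\"conds: (N, T, D) stack of per-image id conditions; returns a fused (1, T', D).\"\"\"\r\n    if method == 'mean':           # official behaviour: average over images\r\n        return conds.mean(dim=0, keepdim=True)\r\n    if method == 'concat':         # keep every token from every image\r\n        return conds.reshape(1, -1, conds.shape[-1])\r\n    if method == 'max':            # element-wise maximum over images\r\n        return conds.max(dim=0, keepdim=True).values\r\n    if method == 'norm_weighted':  # weight each image's condition by its norm\r\n        w = conds.norm(dim=(1, 2), keepdim=True)\r\n        return (conds * (w / w.sum())).sum(dim=0, keepdim=True)\r\n    if method == 'max_token':      # per token position, keep the token with the largest norm\r\n        idx = conds.norm(dim=-1).argmax(dim=0)\r\n        return conds[idx, torch.arange(conds.shape[1])].unsqueeze(0)\r\n    raise ValueError(f'unknown method: {method}')\r\n\r\n\r\nif __name__ == '__main__':\r\n    conds = torch.randn(4, 32, 2048)  # e.g. 4 id images, 32 tokens, 2048 dims\r\n    for m in ('mean', 'concat', 'max', 'norm_weighted', 'max_token'):\r\n        print(m, fuse_id_conditions(conds, m).shape)\r\n"
  },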
  {
    "path": "__init__.py",
    "content": "from .pulidflux import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS\r\n\r\n__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']\r\n"
  },
  {
    "path": "encoders_flux.py",
    "content": "import math\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\n\r\n# FFN\r\ndef FeedForward(dim, mult=4):\r\n    inner_dim = int(dim * mult)\r\n    return nn.Sequential(\r\n        nn.LayerNorm(dim),\r\n        nn.Linear(dim, inner_dim, bias=False),\r\n        nn.GELU(),\r\n        nn.Linear(inner_dim, dim, bias=False),\r\n    )\r\n\r\n\r\ndef reshape_tensor(x, heads):\r\n    bs, length, width = x.shape\r\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\r\n    x = x.view(bs, length, heads, -1)\r\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\r\n    x = x.transpose(1, 2)\r\n    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\r\n    x = x.reshape(bs, heads, length, -1)\r\n    return x\r\n\r\n\r\nclass PerceiverAttentionCA(nn.Module):\r\n    def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):\r\n        super().__init__()\r\n        self.scale = dim_head ** -0.5\r\n        self.dim_head = dim_head\r\n        self.heads = heads\r\n        inner_dim = dim_head * heads\r\n\r\n        self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)\r\n        self.norm2 = nn.LayerNorm(dim)\r\n\r\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\r\n        self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)\r\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\r\n\r\n    def forward(self, x, latents):\r\n        \"\"\"\r\n        Args:\r\n            x (torch.Tensor): image features\r\n                shape (b, n1, D)\r\n            latent (torch.Tensor): latent features\r\n                shape (b, n2, D)\r\n        \"\"\"\r\n        x = self.norm1(x)\r\n        latents = self.norm2(latents)\r\n\r\n        b, seq_len, _ = latents.shape\r\n\r\n        q = self.to_q(latents)\r\n        k, v = self.to_kv(x).chunk(2, dim=-1)\r\n\r\n        q = reshape_tensor(q, self.heads)\r\n        k = reshape_tensor(k, self.heads)\r\n        v = reshape_tensor(v, self.heads)\r\n\r\n        # attention\r\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\r\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\r\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\r\n        out = weight @ v\r\n\r\n        out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)\r\n\r\n        return self.to_out(out)\r\n\r\n\r\nclass PerceiverAttention(nn.Module):\r\n    def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):\r\n        super().__init__()\r\n        self.scale = dim_head ** -0.5\r\n        self.dim_head = dim_head\r\n        self.heads = heads\r\n        inner_dim = dim_head * heads\r\n\r\n        self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)\r\n        self.norm2 = nn.LayerNorm(dim)\r\n\r\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\r\n        self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)\r\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\r\n\r\n    def forward(self, x, latents):\r\n        \"\"\"\r\n        Args:\r\n            x (torch.Tensor): image features\r\n                shape (b, n1, D)\r\n            latent (torch.Tensor): latent features\r\n                shape (b, n2, D)\r\n        \"\"\"\r\n        x = self.norm1(x)\r\n        latents = self.norm2(latents)\r\n\r\n        b, seq_len, _ = latents.shape\r\n\r\n        q = self.to_q(latents)\r\n        kv_input = torch.cat((x, 
latents), dim=-2)\r\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\r\n\r\n        q = reshape_tensor(q, self.heads)\r\n        k = reshape_tensor(k, self.heads)\r\n        v = reshape_tensor(v, self.heads)\r\n\r\n        # attention\r\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\r\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\r\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\r\n        out = weight @ v\r\n\r\n        out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)\r\n\r\n        return self.to_out(out)\r\n\r\n\r\nclass IDFormer(nn.Module):\r\n    \"\"\"\r\n    - perceiver resampler like arch (compared with previous MLP-like arch)\r\n    - we concat id embedding (generated by arcface) and query tokens as latents\r\n    - latents will attend each other and interact with vit features through cross-attention\r\n    - vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two\r\n      IDFormer layers\r\n    \"\"\"\r\n    def __init__(\r\n            self,\r\n            dim=1024,\r\n            depth=10,\r\n            dim_head=64,\r\n            heads=16,\r\n            num_id_token=5,\r\n            num_queries=32,\r\n            output_dim=2048,\r\n            ff_mult=4,\r\n    ):\r\n        super().__init__()\r\n\r\n        self.num_id_token = num_id_token\r\n        self.dim = dim\r\n        self.num_queries = num_queries\r\n        assert depth % 5 == 0\r\n        self.depth = depth // 5\r\n        scale = dim ** -0.5\r\n\r\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)\r\n        self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))\r\n\r\n        self.layers = nn.ModuleList([])\r\n        for _ in range(depth):\r\n            self.layers.append(\r\n                nn.ModuleList(\r\n                    [\r\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\r\n                        FeedForward(dim=dim, mult=ff_mult),\r\n                    ]\r\n                )\r\n            )\r\n\r\n        for i in range(5):\r\n            setattr(\r\n                self,\r\n                f'mapping_{i}',\r\n                nn.Sequential(\r\n                    nn.Linear(1024, 1024),\r\n                    nn.LayerNorm(1024),\r\n                    nn.LeakyReLU(),\r\n                    nn.Linear(1024, 1024),\r\n                    nn.LayerNorm(1024),\r\n                    nn.LeakyReLU(),\r\n                    nn.Linear(1024, dim),\r\n                ),\r\n            )\r\n\r\n        self.id_embedding_mapping = nn.Sequential(\r\n            nn.Linear(1280, 1024),\r\n            nn.LayerNorm(1024),\r\n            nn.LeakyReLU(),\r\n            nn.Linear(1024, 1024),\r\n            nn.LayerNorm(1024),\r\n            nn.LeakyReLU(),\r\n            nn.Linear(1024, dim * num_id_token),\r\n        )\r\n\r\n    def forward(self, x, y):\r\n\r\n        latents = self.latents.repeat(x.size(0), 1, 1)\r\n\r\n        x = self.id_embedding_mapping(x)\r\n        x = x.reshape(-1, self.num_id_token, self.dim)\r\n\r\n        latents = torch.cat((latents, x), dim=1)\r\n\r\n        for i in range(5):\r\n            vit_feature = getattr(self, f'mapping_{i}')(y[i])\r\n            ctx_feature = torch.cat((x, vit_feature), dim=1)\r\n            for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:\r\n                latents = attn(ctx_feature, latents) + 
latents\r\n                latents = ff(latents) + latents\r\n\r\n        latents = latents[:, :self.num_queries]\r\n        latents = latents @ self.proj_out\r\n        return latents\r\n"
  },
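  {
    "path": "example_idformer_shapes.py",
    "content": "\"\"\"Minimal shape-check sketch for encoders_flux.IDFormer.\r\n\r\nThe batch size and the ViT token count (577) are illustrative assumptions;\r\nonly the id-embedding width (1280, per id_embedding_mapping), the five\r\nmulti-scale 1024-dim features, and the (num_queries, output_dim) output follow\r\nfrom the module definition. Run from the repository root so that\r\nencoders_flux is importable.\r\n\"\"\"\r\nimport torch\r\n\r\nfrom encoders_flux import IDFormer\r\n\r\nmodel = IDFormer()  # defaults: dim=1024, depth=10, num_queries=32, output_dim=2048\r\nid_embed = torch.randn(1, 1280)                            # id embedding (1280-dim input to id_embedding_mapping)\r\nvit_feats = [torch.randn(1, 577, 1024) for _ in range(5)]  # five multi-scale ViT hidden states\r\n\r\nwith torch.no_grad():\r\n    out = model(id_embed, vit_feats)\r\nprint(out.shape)  # torch.Size([1, 32, 2048])\r\n"
  },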
  {
    "path": "eva_clip/__init__.py",
    "content": "from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\r\nfrom .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms\r\nfrom .factory import list_models, add_model_config, get_model_config, load_checkpoint\r\nfrom .loss import ClipLoss\r\nfrom .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\\\r\n    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype\r\nfrom .openai import load_openai_model, list_openai_models\r\nfrom .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\\\r\n    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained\r\nfrom .tokenizer import SimpleTokenizer, tokenize\r\nfrom .transform import image_transform"
  },
  {
    "path": "eva_clip/constants.py",
    "content": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\r\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\r\n"
  },
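  {
    "path": "example_clip_normalize.py",
    "content": "\"\"\"Small sketch: normalizing an image tensor with the OpenAI CLIP statistics\r\nfrom eva_clip/constants.py. The 336x336 resolution is an arbitrary assumption;\r\nthe actual preprocessing pipeline lives in eva_clip.transform.image_transform.\r\n\"\"\"\r\nimport torch\r\nfrom torchvision import transforms\r\n\r\nfrom eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\r\n\r\n# standard channel-wise normalization with the CLIP training mean/std\r\nnormalize = transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)\r\npixel_values = normalize(torch.rand(3, 336, 336))  # (C, H, W) float image in [0, 1]\r\nprint(pixel_values.mean().item())\r\n"
  },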
  {
    "path": "eva_clip/eva_vit_model.py",
    "content": "# --------------------------------------------------------\r\n# Adapted from  https://github.com/microsoft/unilm/tree/master/beit\r\n# --------------------------------------------------------\r\nimport math\r\nimport os\r\nfrom functools import partial\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\ntry:\r\n    from timm.models.layers import drop_path, to_2tuple, trunc_normal_\r\nexcept:\r\n    from timm.layers import drop_path, to_2tuple, trunc_normal_\r\n    \r\nfrom .transformer import PatchDropout\r\nfrom .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast\r\n\r\nif os.getenv('ENV_TYPE') == 'deepspeed':\r\n    try:\r\n        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint\r\n    except:\r\n        from torch.utils.checkpoint import checkpoint\r\nelse:\r\n    from torch.utils.checkpoint import checkpoint\r\n\r\ntry:\r\n    import xformers\r\n    import xformers.ops as xops\r\n    XFORMERS_IS_AVAILBLE = True\r\nexcept:\r\n    XFORMERS_IS_AVAILBLE = False\r\n\r\nclass DropPath(nn.Module):\r\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\r\n    \"\"\"\r\n    def __init__(self, drop_prob=None):\r\n        super(DropPath, self).__init__()\r\n        self.drop_prob = drop_prob\r\n\r\n    def forward(self, x):\r\n        return drop_path(x, self.drop_prob, self.training)\r\n    \r\n    def extra_repr(self) -> str:\r\n        return 'p={}'.format(self.drop_prob)\r\n\r\n\r\nclass Mlp(nn.Module):\r\n    def __init__(\r\n        self, \r\n        in_features, \r\n        hidden_features=None, \r\n        out_features=None, \r\n        act_layer=nn.GELU, \r\n        norm_layer=nn.LayerNorm, \r\n        drop=0.,\r\n        subln=False,\r\n\r\n        ):\r\n        super().__init__()\r\n        out_features = out_features or in_features\r\n        hidden_features = hidden_features or in_features\r\n        self.fc1 = nn.Linear(in_features, hidden_features)\r\n        self.act = act_layer()\r\n\r\n        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()\r\n\r\n        self.fc2 = nn.Linear(hidden_features, out_features)\r\n        self.drop = nn.Dropout(drop)\r\n\r\n    def forward(self, x):\r\n        x = self.fc1(x)\r\n        x = self.act(x)\r\n        # x = self.drop(x)\r\n        # commit this for the orignal BERT implement \r\n        x = self.ffn_ln(x)\r\n\r\n        x = self.fc2(x)\r\n        x = self.drop(x)\r\n        return x\r\n\r\nclass SwiGLU(nn.Module):\r\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., \r\n                norm_layer=nn.LayerNorm, subln=False):\r\n        super().__init__()\r\n        out_features = out_features or in_features\r\n        hidden_features = hidden_features or in_features\r\n\r\n        self.w1 = nn.Linear(in_features, hidden_features)\r\n        self.w2 = nn.Linear(in_features, hidden_features)\r\n\r\n        self.act = act_layer()\r\n        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()\r\n        self.w3 = nn.Linear(hidden_features, out_features)\r\n        \r\n        self.drop = nn.Dropout(drop)\r\n\r\n    def forward(self, x):\r\n        x1 = self.w1(x)\r\n        x2 = self.w2(x)\r\n        hidden = self.act(x1) * x2\r\n        x = self.ffn_ln(hidden)\r\n        x = self.w3(x)\r\n        x = self.drop(x)\r\n        return x\r\n\r\nclass Attention(nn.Module):\r\n    def __init__(\r\n            self, dim, 
num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,\r\n            proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):\r\n        super().__init__()\r\n        self.num_heads = num_heads\r\n        head_dim = dim // num_heads\r\n        if attn_head_dim is not None:\r\n            head_dim = attn_head_dim\r\n        all_head_dim = head_dim * self.num_heads\r\n        self.scale = qk_scale or head_dim ** -0.5\r\n\r\n        self.subln = subln\r\n        if self.subln:\r\n            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)\r\n            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)\r\n            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)\r\n        else:\r\n            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)\r\n\r\n        if qkv_bias:\r\n            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))\r\n            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))\r\n        else:\r\n            self.q_bias = None\r\n            self.v_bias = None\r\n\r\n        if window_size:\r\n            self.window_size = window_size\r\n            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\r\n            self.relative_position_bias_table = nn.Parameter(\r\n                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\r\n            # cls to token & token 2 cls & cls to cls\r\n\r\n            # get pair-wise relative position index for each token inside the window\r\n            coords_h = torch.arange(window_size[0])\r\n            coords_w = torch.arange(window_size[1])\r\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\r\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\r\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\r\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\r\n            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\r\n            relative_coords[:, :, 1] += window_size[1] - 1\r\n            relative_coords[:, :, 0] *= 2 * window_size[1] - 1\r\n            relative_position_index = \\\r\n                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)\r\n            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\r\n            relative_position_index[0, 0:] = self.num_relative_distance - 3\r\n            relative_position_index[0:, 0] = self.num_relative_distance - 2\r\n            relative_position_index[0, 0] = self.num_relative_distance - 1\r\n\r\n            self.register_buffer(\"relative_position_index\", relative_position_index)\r\n        else:\r\n            self.window_size = None\r\n            self.relative_position_bias_table = None\r\n            self.relative_position_index = None\r\n\r\n        self.attn_drop = nn.Dropout(attn_drop)\r\n        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()\r\n        # self.proj = nn.Linear(all_head_dim, all_head_dim)\r\n        self.proj = nn.Linear(all_head_dim, dim)\r\n        self.proj_drop = nn.Dropout(proj_drop)\r\n        self.xattn = xattn\r\n        self.xattn_drop = attn_drop\r\n\r\n        self.rope = rope\r\n\r\n    def forward(self, x, rel_pos_bias=None, attn_mask=None):\r\n        B, N, C = x.shape\r\n        if self.subln: \r\n            q = 
F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)\r\n            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)\r\n            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)\r\n\r\n            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)     # B, num_heads, N, C\r\n            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  \r\n            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) \r\n        else: \r\n\r\n            qkv_bias = None\r\n            if self.q_bias is not None:\r\n                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\r\n            \r\n            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\r\n            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)   # 3, B, num_heads, N, C\r\n            q, k, v = qkv[0], qkv[1], qkv[2]\r\n\r\n        if self.rope:\r\n            # slightly fast impl\r\n            q_t = q[:, :, 1:, :]\r\n            ro_q_t = self.rope(q_t)\r\n            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)\r\n\r\n            k_t = k[:, :, 1:, :]\r\n            ro_k_t = self.rope(k_t)\r\n            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)\r\n\r\n        if self.xattn:\r\n            q = q.permute(0, 2, 1, 3)   # B, num_heads, N, C -> B, N, num_heads, C\r\n            k = k.permute(0, 2, 1, 3)\r\n            v = v.permute(0, 2, 1, 3)\r\n\r\n            x = xops.memory_efficient_attention(\r\n                q, k, v,\r\n                p=self.xattn_drop,\r\n                scale=self.scale,\r\n                )\r\n            x = x.reshape(B, N, -1)\r\n            x = self.inner_attn_ln(x)\r\n            x = self.proj(x)\r\n            x = self.proj_drop(x)\r\n        else:\r\n            q = q * self.scale\r\n            attn = (q @ k.transpose(-2, -1))\r\n\r\n            if self.relative_position_bias_table is not None:\r\n                relative_position_bias = \\\r\n                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\r\n                        self.window_size[0] * self.window_size[1] + 1,\r\n                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\r\n                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\r\n                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)\r\n\r\n            if rel_pos_bias is not None:\r\n                attn = attn + rel_pos_bias.type_as(attn)\r\n\r\n            if attn_mask is not None:\r\n                attn_mask = attn_mask.bool()\r\n                attn = attn.masked_fill(~attn_mask[:, None, None, :], float(\"-inf\"))\r\n            \r\n            attn = attn.softmax(dim=-1)\r\n            attn = self.attn_drop(attn)\r\n\r\n            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\r\n            x = self.inner_attn_ln(x)\r\n            x = self.proj(x)\r\n            x = self.proj_drop(x)\r\n        return x\r\n\r\n\r\nclass Block(nn.Module):\r\n\r\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\r\n                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,\r\n                 window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,\r\n                 subln=False, naiveswiglu=False):\r\n        super().__init__()\r\n        
self.norm1 = norm_layer(dim)\r\n        self.attn = Attention(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\r\n            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,\r\n            xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)\r\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n\r\n        if naiveswiglu:\r\n            self.mlp = SwiGLU(\r\n                in_features=dim, \r\n                hidden_features=mlp_hidden_dim, \r\n                subln=subln,\r\n                norm_layer=norm_layer,\r\n            )\r\n        else:\r\n            self.mlp = Mlp(\r\n                in_features=dim, \r\n                hidden_features=mlp_hidden_dim, \r\n                act_layer=act_layer,\r\n                subln=subln,\r\n                drop=drop\r\n            )\r\n\r\n        if init_values is not None and init_values > 0:\r\n            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\r\n            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\r\n        else:\r\n            self.gamma_1, self.gamma_2 = None, None\r\n\r\n        self.postnorm = postnorm\r\n\r\n    def forward(self, x, rel_pos_bias=None, attn_mask=None):\r\n        if self.gamma_1 is None:\r\n            if self.postnorm:\r\n                x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))\r\n                x = x + self.drop_path(self.norm2(self.mlp(x)))\r\n            else:\r\n                x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))\r\n                x = x + self.drop_path(self.mlp(self.norm2(x)))\r\n        else:\r\n            if self.postnorm:\r\n                x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))\r\n                x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))\r\n            else:\r\n                x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))\r\n                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\r\n        return x\r\n\r\n\r\nclass PatchEmbed(nn.Module):\r\n    \"\"\" Image to Patch Embedding\r\n    \"\"\"\r\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\r\n        super().__init__()\r\n        img_size = to_2tuple(img_size)\r\n        patch_size = to_2tuple(patch_size)\r\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\r\n        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\r\n        self.img_size = img_size\r\n        self.patch_size = patch_size\r\n        self.num_patches = num_patches\r\n\r\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\r\n\r\n    def forward(self, x, **kwargs):\r\n        B, C, H, W = x.shape\r\n        # FIXME look at relaxing size constraints\r\n        assert H == self.img_size[0] and W == self.img_size[1], \\\r\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\r\n        x = self.proj(x).flatten(2).transpose(1, 
2)\r\n        return x\r\n\r\n\r\nclass RelativePositionBias(nn.Module):\r\n\r\n    def __init__(self, window_size, num_heads):\r\n        super().__init__()\r\n        self.window_size = window_size\r\n        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3\r\n        self.relative_position_bias_table = nn.Parameter(\r\n            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH\r\n        # cls to token & token 2 cls & cls to cls\r\n\r\n        # get pair-wise relative position index for each token inside the window\r\n        coords_h = torch.arange(window_size[0])\r\n        coords_w = torch.arange(window_size[1])\r\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\r\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\r\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\r\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\r\n        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0\r\n        relative_coords[:, :, 1] += window_size[1] - 1\r\n        relative_coords[:, :, 0] *= 2 * window_size[1] - 1\r\n        relative_position_index = \\\r\n            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)\r\n        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\r\n        relative_position_index[0, 0:] = self.num_relative_distance - 3\r\n        relative_position_index[0:, 0] = self.num_relative_distance - 2\r\n        relative_position_index[0, 0] = self.num_relative_distance - 1\r\n\r\n        self.register_buffer(\"relative_position_index\", relative_position_index)\r\n\r\n    def forward(self):\r\n        relative_position_bias = \\\r\n            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\r\n                self.window_size[0] * self.window_size[1] + 1,\r\n                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH\r\n        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\r\n\r\n\r\nclass EVAVisionTransformer(nn.Module):\r\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\r\n    \"\"\"\r\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\r\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\r\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,\r\n                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,\r\n                 use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,\r\n                 pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):\r\n        super().__init__()\r\n\r\n        if not XFORMERS_IS_AVAILBLE:\r\n            xattn = False\r\n\r\n        self.image_size = img_size\r\n        self.num_classes = num_classes\r\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\r\n\r\n        self.patch_embed = PatchEmbed(\r\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\r\n        num_patches = self.patch_embed.num_patches\r\n\r\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\r\n        # 
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\r\n        if use_abs_pos_emb:\r\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\r\n        else:\r\n            self.pos_embed = None\r\n        self.pos_drop = nn.Dropout(p=drop_rate)\r\n\r\n        if use_shared_rel_pos_bias:\r\n            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)\r\n        else:\r\n            self.rel_pos_bias = None\r\n\r\n        if rope:\r\n            half_head_dim = embed_dim // num_heads // 2\r\n            hw_seq_len = img_size // patch_size\r\n            self.rope = VisionRotaryEmbeddingFast(\r\n                dim=half_head_dim,\r\n                pt_seq_len=pt_hw_seq_len,\r\n                ft_seq_len=hw_seq_len if intp_freq else None,\r\n                # patch_dropout=patch_dropout\r\n            )\r\n        else: \r\n            self.rope = None\r\n\r\n        self.naiveswiglu = naiveswiglu\r\n\r\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\r\n        self.use_rel_pos_bias = use_rel_pos_bias\r\n        self.blocks = nn.ModuleList([\r\n            Block(\r\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\r\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\r\n                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,\r\n                xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)\r\n            for i in range(depth)])\r\n        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)\r\n        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None\r\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\r\n\r\n        if self.pos_embed is not None:\r\n            trunc_normal_(self.pos_embed, std=.02)\r\n\r\n        trunc_normal_(self.cls_token, std=.02)\r\n        # trunc_normal_(self.mask_token, std=.02)\r\n\r\n        self.apply(self._init_weights)\r\n        self.fix_init_weight()\r\n\r\n        if isinstance(self.head, nn.Linear):\r\n            trunc_normal_(self.head.weight, std=.02)\r\n            self.head.weight.data.mul_(init_scale)\r\n            self.head.bias.data.mul_(init_scale)\r\n\r\n        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\r\n        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity()\r\n\r\n        self.grad_checkpointing = grad_checkpointing\r\n\r\n    def fix_init_weight(self):\r\n        def rescale(param, layer_id):\r\n            param.div_(math.sqrt(2.0 * layer_id))\r\n\r\n        for layer_id, layer in enumerate(self.blocks):\r\n            rescale(layer.attn.proj.weight.data, layer_id + 1)\r\n            if self.naiveswiglu:\r\n                rescale(layer.mlp.w3.weight.data, layer_id + 1)\r\n            else:\r\n                rescale(layer.mlp.fc2.weight.data, layer_id + 1)\r\n\r\n    def get_cast_dtype(self) -> torch.dtype:\r\n        return self.blocks[0].mlp.fc2.weight.dtype\r\n\r\n    def _init_weights(self, m):\r\n        if isinstance(m, nn.Linear):\r\n            trunc_normal_(m.weight, std=.02)\r\n            if m.bias is not None:\r\n                nn.init.constant_(m.bias, 0)\r\n        elif isinstance(m, nn.LayerNorm):\r\n            nn.init.constant_(m.bias, 0)\r\n            nn.init.constant_(m.weight, 1.0)\r\n\r\n    def get_num_layers(self):\r\n        return len(self.blocks)\r\n    \r\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        assert unlocked_groups == 0, 'partial locking not currently supported for this model'\r\n        for param in self.parameters():\r\n            param.requires_grad = False\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.grad_checkpointing = enable\r\n\r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        return {'pos_embed', 'cls_token'}\r\n\r\n    def get_classifier(self):\r\n        return self.head\r\n\r\n    def reset_classifier(self, num_classes, global_pool=''):\r\n        self.num_classes = num_classes\r\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\r\n\r\n    def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):\r\n        \r\n        x = self.patch_embed(x)\r\n        batch_size, seq_len, _ = x.size()\r\n\r\n        if shuffle:\r\n            idx = torch.randperm(x.shape[1]) + 1\r\n            zero = torch.LongTensor([0, ])\r\n            idx = torch.cat([zero, idx])\r\n            pos_embed = self.pos_embed[:, idx]\r\n\r\n        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\r\n        x = torch.cat((cls_tokens, x), dim=1)\r\n        if shuffle:\r\n            x = x + pos_embed\r\n        elif self.pos_embed is not None:\r\n            x = x + self.pos_embed\r\n        x = self.pos_drop(x)\r\n\r\n        # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in\r\n        if os.getenv('RoPE') == '1':\r\n            if self.training and not isinstance(self.patch_dropout, nn.Identity):\r\n                x, patch_indices_keep = self.patch_dropout(x)\r\n                self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)\r\n            else:\r\n                self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)\r\n                x = self.patch_dropout(x)\r\n        else:\r\n            x = self.patch_dropout(x)\r\n\r\n        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None\r\n        hidden_states = []\r\n        for idx, blk in enumerate(self.blocks):\r\n            if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:\r\n                hidden_states.append(x)\r\n            if self.grad_checkpointing:\r\n                x = checkpoint(blk, x, (rel_pos_bias,))\r\n            else:\r\n                x = blk(x, rel_pos_bias=rel_pos_bias)\r\n\r\n        if not return_all_features:\r\n            x = self.norm(x)\r\n            if self.fc_norm is not None:\r\n                return self.fc_norm(x.mean(1)), hidden_states\r\n            else:\r\n                return x[:, 0], hidden_states\r\n        return x\r\n\r\n    def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):\r\n        if return_all_features:\r\n            return self.forward_features(x, return_all_features, return_hidden, shuffle)\r\n        x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)\r\n        x = self.head(x)\r\n        if return_hidden:\r\n            return x, hidden_states\r\n        return x\r\n"
  },
  {
    "path": "eva_clip/factory.py",
    "content": "import json\r\nimport logging\r\nimport os\r\nimport pathlib\r\nimport re\r\nfrom copy import deepcopy\r\nfrom pathlib import Path\r\nfrom typing import Optional, Tuple, Union, Dict, Any\r\nimport torch\r\n\r\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\r\nfrom .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\\\r\n    get_cast_dtype\r\nfrom .openai import load_openai_model\r\nfrom .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model\r\nfrom .transform import image_transform\r\nfrom .tokenizer import HFTokenizer, tokenize\r\nfrom .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed\r\n\r\n\r\n_MODEL_CONFIG_PATHS = [Path(__file__).parent / f\"model_configs/\"]\r\n_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs\r\n\r\n\r\ndef _natural_key(string_):\r\n    return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\r\n\r\n\r\ndef _rescan_model_configs():\r\n    global _MODEL_CONFIGS\r\n\r\n    config_ext = ('.json',)\r\n    config_files = []\r\n    for config_path in _MODEL_CONFIG_PATHS:\r\n        if config_path.is_file() and config_path.suffix in config_ext:\r\n            config_files.append(config_path)\r\n        elif config_path.is_dir():\r\n            for ext in config_ext:\r\n                config_files.extend(config_path.glob(f'*{ext}'))\r\n\r\n    for cf in config_files:\r\n        with open(cf, \"r\", encoding=\"utf8\") as f:\r\n            model_cfg = json.load(f)\r\n            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):\r\n                _MODEL_CONFIGS[cf.stem] = model_cfg\r\n\r\n    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))\r\n\r\n\r\n_rescan_model_configs()  # initial populate of model config registry\r\n\r\n\r\ndef list_models():\r\n    \"\"\" enumerate available model architectures based on config files \"\"\"\r\n    return list(_MODEL_CONFIGS.keys())\r\n\r\n\r\ndef add_model_config(path):\r\n    \"\"\" add model config path or file and update registry \"\"\"\r\n    if not isinstance(path, Path):\r\n        path = Path(path)\r\n    _MODEL_CONFIG_PATHS.append(path)\r\n    _rescan_model_configs()\r\n\r\n\r\ndef get_model_config(model_name):\r\n    if model_name in _MODEL_CONFIGS:\r\n        return deepcopy(_MODEL_CONFIGS[model_name])\r\n    else:\r\n        return None\r\n\r\n\r\ndef get_tokenizer(model_name):\r\n    config = get_model_config(model_name)\r\n    tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize\r\n    return tokenizer\r\n\r\n\r\n# loading openai CLIP weights when is_openai=True for training\r\ndef load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):\r\n    if is_openai:\r\n        model = torch.jit.load(checkpoint_path, map_location=\"cpu\").eval()\r\n        state_dict = model.state_dict()\r\n        for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\r\n            state_dict.pop(key, None)\r\n    else:\r\n        checkpoint = torch.load(checkpoint_path, map_location=map_location)\r\n        for mk in model_key.split('|'):\r\n            if isinstance(checkpoint, dict) and mk in checkpoint:\r\n                state_dict = checkpoint[mk]\r\n                
break\r\n            else:\r\n                state_dict = checkpoint\r\n        if next(iter(state_dict.items()))[0].startswith('module'):\r\n            state_dict = {k[7:]: v for k, v in state_dict.items()}\r\n    \r\n    for k in skip_list:\r\n        if k in list(state_dict.keys()):\r\n            logging.info(f\"Removing key {k} from pretrained checkpoint\")\r\n            del state_dict[k]\r\n\r\n    if os.getenv('RoPE') == '1':\r\n        for k in list(state_dict.keys()):\r\n            if 'freqs_cos' in k or 'freqs_sin' in k:\r\n                del state_dict[k]\r\n    return state_dict\r\n\r\n\r\n\r\ndef load_checkpoint(model, checkpoint_path, model_key=\"model|module|state_dict\", strict=True):\r\n    state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)\r\n    # detect old format and make compatible with new format\r\n    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):\r\n        state_dict = convert_to_custom_text_state_dict(state_dict)\r\n    if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):\r\n        state_dict['logit_scale'] = state_dict['text.logit_scale']\r\n        del state_dict['text.logit_scale']\r\n\r\n    # resize_clip_pos_embed for CLIP and open CLIP\r\n    if 'visual.positional_embedding' in state_dict:\r\n        resize_clip_pos_embed(state_dict, model)\r\n    # specified to eva_vit_model\r\n    elif 'visual.pos_embed' in state_dict:\r\n        resize_evaclip_pos_embed(state_dict, model)\r\n\r\n    # resize_clip_pos_embed(state_dict, model)\r\n    incompatible_keys = model.load_state_dict(state_dict, strict=strict)\r\n    logging.info(f\"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}\")\r\n    return incompatible_keys\r\n\r\ndef load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):\r\n    state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)\r\n\r\n    for k in list(state_dict.keys()):\r\n        if not k.startswith('visual.'):\r\n            del state_dict[k]\r\n    for k in list(state_dict.keys()):\r\n        if k.startswith('visual.'):\r\n            new_k = k[7:]\r\n            state_dict[new_k] = state_dict[k]\r\n            del state_dict[k]\r\n    return state_dict\r\n\r\ndef load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):\r\n    state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)\r\n\r\n    for k in list(state_dict.keys()):\r\n        if k.startswith('visual.'):\r\n            del state_dict[k]\r\n    return state_dict\r\n\r\ndef get_pretrained_tag(pretrained_model):\r\n    pretrained_model = pretrained_model.lower()\r\n    if \"laion\" in pretrained_model or \"open_clip\" in pretrained_model:\r\n        return \"open_clip\"\r\n    elif \"openai\" in pretrained_model:\r\n        return \"clip\"\r\n    elif \"eva\" in pretrained_model and \"clip\" in pretrained_model:\r\n        return \"eva_clip\"\r\n    else:\r\n        return \"other\"\r\n\r\ndef load_pretrained_checkpoint(\r\n        model,\r\n        visual_checkpoint_path,\r\n        text_checkpoint_path,\r\n        strict=True,\r\n        visual_model=None,\r\n        text_model=None,\r\n        model_key=\"model|module|state_dict\",\r\n        skip_list=[]):\r\n    visual_tag = get_pretrained_tag(visual_model)\r\n    text_tag = 
get_pretrained_tag(text_model)\r\n\r\n    logging.info(f\"num of model state_dict keys: {len(model.state_dict().keys())}\")\r\n    visual_incompatible_keys, text_incompatible_keys = None, None\r\n    if visual_checkpoint_path:\r\n        if visual_tag == \"eva_clip\" or visual_tag == \"open_clip\":\r\n            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)\r\n        elif visual_tag == \"clip\":\r\n            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)\r\n        else:\r\n            visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)\r\n    \r\n        # resize_clip_pos_embed for CLIP and open CLIP\r\n        if 'positional_embedding' in visual_state_dict:\r\n            resize_visual_pos_embed(visual_state_dict, model)\r\n        # specified to EVA model\r\n        elif 'pos_embed' in visual_state_dict:\r\n            resize_eva_pos_embed(visual_state_dict, model)\r\n\r\n        visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)\r\n        logging.info(f\"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}\")\r\n        logging.info(f\"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}\")\r\n\r\n    if text_checkpoint_path:\r\n        if text_tag == \"eva_clip\" or text_tag == \"open_clip\":\r\n            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)\r\n        elif text_tag == \"clip\":\r\n            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)\r\n        else:\r\n            text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)\r\n\r\n        text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)\r\n        \r\n        logging.info(f\"num of loaded text_state_dict keys: {len(text_state_dict.keys())}\")\r\n        logging.info(f\"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}\")\r\n\r\n    return visual_incompatible_keys, text_incompatible_keys\r\n\r\ndef create_model(\r\n        model_name: str,\r\n        pretrained: Optional[str] = None,\r\n        precision: str = 'fp32',\r\n        device: Union[str, torch.device] = 'cpu',\r\n        jit: bool = False,\r\n        force_quick_gelu: bool = False,\r\n        force_custom_clip: bool = False,\r\n        force_patch_dropout: Optional[float] = None,\r\n        pretrained_image: str = '',\r\n        pretrained_text: str = '',\r\n        pretrained_hf: bool = True,\r\n        pretrained_visual_model: str = None,\r\n        pretrained_text_model: str = None,\r\n        cache_dir: Optional[str] = None,\r\n        skip_list: list  = [],\r\n):\r\n    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names\r\n    if isinstance(device, str):\r\n        device = torch.device(device)\r\n\r\n    if pretrained and pretrained.lower() == 'openai':\r\n        logging.info(f'Loading pretrained {model_name} from OpenAI.')\r\n        model = load_openai_model(\r\n            model_name,\r\n            precision=precision,\r\n            device=device,\r\n            jit=jit,\r\n            cache_dir=cache_dir,\r\n        )\r\n    else:\r\n        model_cfg = get_model_config(model_name)\r\n        if 
model_cfg is not None:\r\n            logging.info(f'Loaded {model_name} model config.')\r\n        else:\r\n            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')\r\n            raise RuntimeError(f'Model config for {model_name} not found.')\r\n\r\n        if 'rope' in model_cfg.get('vision_cfg', {}):\r\n            if model_cfg['vision_cfg']['rope']:\r\n                os.environ['RoPE'] = \"1\"\r\n        else:\r\n            os.environ['RoPE'] = \"0\"\r\n\r\n        if force_quick_gelu:\r\n            # override for use of QuickGELU on non-OpenAI transformer models\r\n            model_cfg[\"quick_gelu\"] = True\r\n        \r\n        if force_patch_dropout is not None:\r\n            # override the default patch dropout value\r\n            model_cfg['vision_cfg'][\"patch_dropout\"] = force_patch_dropout\r\n\r\n        cast_dtype = get_cast_dtype(precision)\r\n        custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])\r\n\r\n\r\n        if custom_clip:\r\n            if 'hf_model_name' in model_cfg.get('text_cfg', {}):\r\n                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf\r\n            model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)\r\n        else:\r\n            model = CLIP(**model_cfg, cast_dtype=cast_dtype)\r\n\r\n        pretrained_cfg = {}\r\n        if pretrained:\r\n            checkpoint_path = ''\r\n            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)\r\n            if pretrained_cfg:\r\n                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)\r\n            elif os.path.exists(pretrained):\r\n                checkpoint_path = pretrained\r\n\r\n            if checkpoint_path:\r\n                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\r\n                load_checkpoint(model,\r\n                               checkpoint_path,\r\n                               model_key=\"model|module|state_dict\",\r\n                               strict=False\r\n                               ) \r\n            else:\r\n                error_str = (\r\n                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'\r\n                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')\r\n                logging.warning(error_str)\r\n                raise RuntimeError(error_str)\r\n        else:\r\n            visual_checkpoint_path = ''\r\n            text_checkpoint_path = ''\r\n            \r\n            if pretrained_image:\r\n                pretrained_visual_model = pretrained_visual_model.replace('/', '-')  # for callers using old naming with / in ViT names\r\n                pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)\r\n                if 'timm_model_name' in model_cfg.get('vision_cfg', {}):\r\n                    # pretrained weight loading for timm models set via vision_cfg\r\n                    model_cfg['vision_cfg']['timm_model_pretrained'] = True\r\n                elif pretrained_image_cfg:\r\n                    visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)\r\n                elif os.path.exists(pretrained_image):\r\n                    visual_checkpoint_path = pretrained_image\r\n                else:\r\n                    logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model 
{model_name}.visual.')\r\n                    raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')\r\n\r\n            if pretrained_text:\r\n                pretrained_text_model = pretrained_text_model.replace('/', '-')  # for callers using old naming with / in ViT names\r\n                pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)\r\n                if pretrained_text_cfg:\r\n                    text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)\r\n                elif os.path.exists(pretrained_text):\r\n                    text_checkpoint_path = pretrained_text\r\n                else:\r\n                    logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')\r\n                    raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')\r\n            \r\n            if visual_checkpoint_path:\r\n                logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')\r\n            if text_checkpoint_path:\r\n                logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')\r\n\r\n            if visual_checkpoint_path or text_checkpoint_path:\r\n                load_pretrained_checkpoint(\r\n                    model,\r\n                    visual_checkpoint_path,\r\n                    text_checkpoint_path,\r\n                    strict=False,\r\n                    visual_model=pretrained_visual_model,\r\n                    text_model=pretrained_text_model,\r\n                    model_key=\"model|module|state_dict\",\r\n                    skip_list=skip_list\r\n                )\r\n        \r\n        if \"fp16\" in precision or \"bf16\" in precision:\r\n            logging.info(f'convert precision to {precision}')\r\n            model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)\r\n\r\n        model.to(device=device)\r\n\r\n        # set image mean / std metadata from pretrained_cfg if available, or use default\r\n        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN\r\n        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD\r\n\r\n        if jit:\r\n            model = torch.jit.script(model)\r\n\r\n    return model\r\n\r\n\r\ndef create_model_and_transforms(\r\n        model_name: str,\r\n        pretrained: Optional[str] = None,\r\n        precision: str = 'fp32',\r\n        device: Union[str, torch.device] = 'cpu',\r\n        jit: bool = False,\r\n        force_quick_gelu: bool = False,\r\n        force_custom_clip: bool = False,\r\n        force_patch_dropout: Optional[float] = None,\r\n        pretrained_image: str = '',\r\n        pretrained_text: str = '',\r\n        pretrained_hf: bool = True,\r\n        pretrained_visual_model: str = None,\r\n        pretrained_text_model: str = None,\r\n        image_mean: Optional[Tuple[float, ...]] = None,\r\n        image_std: Optional[Tuple[float, ...]] = None,\r\n        cache_dir: Optional[str] = None,\r\n        skip_list: list = [],\r\n):\r\n    model = create_model(\r\n        model_name,\r\n        pretrained,\r\n        precision=precision,\r\n        device=device,\r\n        jit=jit,\r\n        force_quick_gelu=force_quick_gelu,\r\n        force_custom_clip=force_custom_clip,\r\n        
force_patch_dropout=force_patch_dropout,\r\n        pretrained_image=pretrained_image,\r\n        pretrained_text=pretrained_text,\r\n        pretrained_hf=pretrained_hf,\r\n        pretrained_visual_model=pretrained_visual_model,\r\n        pretrained_text_model=pretrained_text_model,\r\n        cache_dir=cache_dir,\r\n        skip_list=skip_list,\r\n    )\r\n\r\n    image_mean = image_mean or getattr(model.visual, 'image_mean', None)\r\n    image_std = image_std or getattr(model.visual, 'image_std', None)\r\n    preprocess_train = image_transform(\r\n        model.visual.image_size,\r\n        is_train=True,\r\n        mean=image_mean,\r\n        std=image_std\r\n    )\r\n    preprocess_val = image_transform(\r\n        model.visual.image_size,\r\n        is_train=False,\r\n        mean=image_mean,\r\n        std=image_std\r\n    )\r\n\r\n    return model, preprocess_train, preprocess_val\r\n\r\n\r\ndef create_transforms(\r\n        model_name: str,\r\n        pretrained: Optional[str] = None,\r\n        precision: str = 'fp32',\r\n        device: Union[str, torch.device] = 'cpu',\r\n        jit: bool = False,\r\n        force_quick_gelu: bool = False,\r\n        force_custom_clip: bool = False,\r\n        force_patch_dropout: Optional[float] = None,\r\n        pretrained_image: str = '',\r\n        pretrained_text: str = '',\r\n        pretrained_hf: bool = True,\r\n        pretrained_visual_model: str = None,\r\n        pretrained_text_model: str = None,\r\n        image_mean: Optional[Tuple[float, ...]] = None,\r\n        image_std: Optional[Tuple[float, ...]] = None,\r\n        cache_dir: Optional[str] = None,\r\n        skip_list: list = [],\r\n):\r\n    model = create_model(\r\n        model_name,\r\n        pretrained,\r\n        precision=precision,\r\n        device=device,\r\n        jit=jit,\r\n        force_quick_gelu=force_quick_gelu,\r\n        force_custom_clip=force_custom_clip,\r\n        force_patch_dropout=force_patch_dropout,\r\n        pretrained_image=pretrained_image,\r\n        pretrained_text=pretrained_text,\r\n        pretrained_hf=pretrained_hf,\r\n        pretrained_visual_model=pretrained_visual_model,\r\n        pretrained_text_model=pretrained_text_model,\r\n        cache_dir=cache_dir,\r\n        skip_list=skip_list,\r\n    )\r\n\r\n\r\n    image_mean = image_mean or getattr(model.visual, 'image_mean', None)\r\n    image_std = image_std or getattr(model.visual, 'image_std', None)\r\n    preprocess_train = image_transform(\r\n        model.visual.image_size,\r\n        is_train=True,\r\n        mean=image_mean,\r\n        std=image_std\r\n    )\r\n    preprocess_val = image_transform(\r\n        model.visual.image_size,\r\n        is_train=False,\r\n        mean=image_mean,\r\n        std=image_std\r\n    )\r\n    del model\r\n\r\n    return preprocess_train, preprocess_val\r\n\r\ndef create_model_from_pretrained(\r\n        model_name: str,\r\n        pretrained: str,\r\n        precision: str = 'fp32',\r\n        device: Union[str, torch.device] = 'cpu',\r\n        jit: bool = False,\r\n        force_quick_gelu: bool = False,\r\n        force_custom_clip: bool = False,\r\n        force_patch_dropout: Optional[float] = None,\r\n        return_transform: bool = True,\r\n        image_mean: Optional[Tuple[float, ...]] = None,\r\n        image_std: Optional[Tuple[float, ...]] = None,\r\n        cache_dir: Optional[str] = None,\r\n        is_frozen: bool = False,\r\n):\r\n    if not is_pretrained_cfg(model_name, pretrained) and not 
os.path.exists(pretrained):\r\n        raise RuntimeError(\r\n            f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'\r\n            f' Use open_clip.list_pretrained() to find one.')\r\n\r\n    model = create_model(\r\n        model_name,\r\n        pretrained,\r\n        precision=precision,\r\n        device=device,\r\n        jit=jit,\r\n        force_quick_gelu=force_quick_gelu,\r\n        force_custom_clip=force_custom_clip,\r\n        force_patch_dropout=force_patch_dropout,\r\n        cache_dir=cache_dir,\r\n    )\r\n\r\n    if is_frozen:\r\n        for param in model.parameters():\r\n            param.requires_grad = False\r\n\r\n    if not return_transform:\r\n        return model\r\n\r\n    image_mean = image_mean or getattr(model.visual, 'image_mean', None)\r\n    image_std = image_std or getattr(model.visual, 'image_std', None)\r\n    preprocess = image_transform(\r\n        model.visual.image_size,\r\n        is_train=False,\r\n        mean=image_mean,\r\n        std=image_std\r\n    )\r\n\r\n    return model, preprocess\r\n"
  },
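  {
    "path": "examples/factory_usage_sketch.py",
    "content": "# Minimal usage sketch for the model factory in factory.py (hypothetical example\r\n# file, not one of the original sources). It assumes the helpers are importable from\r\n# the eva_clip package and that the 'EVA02-CLIP-B-16' config shipped under\r\n# eva_clip/model_configs/ is available; the checkpoint path is a placeholder.\r\nimport torch\r\n\r\nfrom eva_clip import create_model_and_transforms\r\n\r\nmodel, preprocess_train, preprocess_val = create_model_and_transforms(\r\n    'EVA02-CLIP-B-16',\r\n    pretrained='/path/to/checkpoint.pt',  # placeholder path or a known pretrained tag\r\n    precision='fp32',\r\n    device='cpu',\r\n)\r\nmodel.eval()\r\n\r\n# Dummy inputs; real usage would run preprocess_val on PIL images and tokenize captions.\r\nimage = torch.randn(1, 3, model.visual.image_size, model.visual.image_size)\r\ntext = torch.randint(0, 49408, (1, 77))  # vocab_size / context_length from the text_cfg\r\n\r\nwith torch.no_grad():\r\n    image_features, text_features, logit_scale = model(image, text)\r\n    logits_per_image = logit_scale * image_features @ text_features.T\r\n"
  },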
  {
    "path": "eva_clip/hf_configs.py",
    "content": "# HF architecture dict:\r\narch_dict = {\r\n  # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\r\n  \"roberta\": {\r\n      \"config_names\": {\r\n          \"context_length\": \"max_position_embeddings\",\r\n          \"vocab_size\": \"vocab_size\",\r\n          \"width\": \"hidden_size\",\r\n          \"heads\": \"num_attention_heads\",\r\n          \"layers\": \"num_hidden_layers\",\r\n          \"layer_attr\": \"layer\",\r\n          \"token_embeddings_attr\": \"embeddings\"\r\n      },\r\n      \"pooler\": \"mean_pooler\",\r\n  },\r\n  # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig\r\n  \"xlm-roberta\": {\r\n      \"config_names\": {\r\n          \"context_length\": \"max_position_embeddings\",\r\n          \"vocab_size\": \"vocab_size\",\r\n          \"width\": \"hidden_size\",\r\n          \"heads\": \"num_attention_heads\",\r\n          \"layers\": \"num_hidden_layers\",\r\n          \"layer_attr\": \"layer\",\r\n          \"token_embeddings_attr\": \"embeddings\"\r\n      },\r\n      \"pooler\": \"mean_pooler\",\r\n  },\r\n  # https://huggingface.co/docs/transformers/model_doc/mt5#mt5\r\n  \"mt5\": {\r\n      \"config_names\": {\r\n          # unlimited seqlen\r\n          # https://github.com/google-research/text-to-text-transfer-transformer/issues/273\r\n          # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374\r\n          \"context_length\": \"\",\r\n          \"vocab_size\": \"vocab_size\",\r\n          \"width\": \"d_model\",\r\n          \"heads\": \"num_heads\",\r\n          \"layers\": \"num_layers\",\r\n          \"layer_attr\": \"block\",\r\n          \"token_embeddings_attr\": \"embed_tokens\"\r\n      },\r\n      \"pooler\": \"mean_pooler\",\r\n  },\r\n  \"bert\": {\r\n    \"config_names\": {\r\n      \"context_length\": \"max_position_embeddings\",\r\n      \"vocab_size\": \"vocab_size\",\r\n      \"width\": \"hidden_size\",\r\n      \"heads\": \"num_attention_heads\",\r\n      \"layers\": \"num_hidden_layers\",\r\n      \"layer_attr\": \"layer\",\r\n      \"token_embeddings_attr\": \"embeddings\"\r\n    },\r\n    \"pooler\": \"mean_pooler\",\r\n  }\r\n}\r\n"
  },
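  {
    "path": "examples/hf_configs_usage_sketch.py",
    "content": "# Small sketch of how arch_dict in eva_clip/hf_configs.py is used: it maps generic\r\n# text-tower fields (width, heads, layers, context_length, ...) onto the attribute\r\n# names of each HuggingFace config class. Hypothetical example file; assumes the\r\n# transformers package is installed and 'roberta-base' can be downloaded.\r\nfrom transformers import AutoConfig\r\n\r\nfrom eva_clip.hf_configs import arch_dict\r\n\r\nconfig = AutoConfig.from_pretrained('roberta-base')\r\nnames = arch_dict[config.model_type]['config_names']\r\n\r\nwidth = getattr(config, names['width'])    # hidden_size -> 768 for roberta-base\r\nlayers = getattr(config, names['layers'])  # num_hidden_layers -> 12\r\nheads = getattr(config, names['heads'])    # num_attention_heads -> 12\r\ncontext_length = getattr(config, names['context_length'])  # max_position_embeddings\r\n\r\nprint(width, layers, heads, context_length, arch_dict[config.model_type]['pooler'])\r\n"
  },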
  {
    "path": "eva_clip/hf_model.py",
    "content": "\"\"\" huggingface model adapter\r\n\r\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.\r\n\"\"\"\r\n\r\nimport re\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.nn import functional as F\r\nfrom torch import TensorType\r\ntry:\r\n    import transformers\r\n    from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig\r\n    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \\\r\n        BaseModelOutputWithPoolingAndCrossAttentions\r\nexcept ImportError as e:\r\n    transformers = None\r\n\r\n\r\n    class BaseModelOutput:\r\n        pass\r\n\r\n\r\n    class PretrainedConfig:\r\n        pass\r\n\r\nfrom .hf_configs import arch_dict\r\n\r\n# utils\r\ndef _camel2snake(s):\r\n    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()\r\n\r\n# TODO: ?last - for gpt-like models\r\n_POOLERS = {}\r\n\r\ndef register_pooler(cls):\r\n    \"\"\"Decorator registering pooler class\"\"\"\r\n    _POOLERS[_camel2snake(cls.__name__)] = cls\r\n    return cls\r\n\r\n\r\n@register_pooler\r\nclass MeanPooler(nn.Module):\r\n    \"\"\"Mean pooling\"\"\"\r\n    def forward(self, x:BaseModelOutput, attention_mask:TensorType):\r\n        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)\r\n        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)\r\n\r\n@register_pooler\r\nclass MaxPooler(nn.Module):\r\n    \"\"\"Max pooling\"\"\"\r\n    def forward(self, x:BaseModelOutput, attention_mask:TensorType):\r\n        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)\r\n        return masked_output.max(1).values\r\n\r\n@register_pooler\r\nclass ClsPooler(nn.Module):\r\n    \"\"\"CLS token pooling\"\"\"\r\n    def __init__(self, use_pooler_output=True):\r\n        super().__init__()\r\n        self.cls_token_position = 0\r\n        self.use_pooler_output = use_pooler_output\r\n\r\n    def forward(self, x:BaseModelOutput, attention_mask:TensorType):\r\n        \r\n        if (self.use_pooler_output and \r\n            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and\r\n            (x.pooler_output is not None)\r\n            ):\r\n            return x.pooler_output\r\n        \r\n        return x.last_hidden_state[:, self.cls_token_position, :]\r\n\r\nclass HFTextEncoder(nn.Module):\r\n    \"\"\"HuggingFace model adapter\"\"\"\r\n    def __init__(\r\n            self, \r\n            model_name_or_path: str,\r\n            output_dim: int,\r\n            tokenizer_name: str = None,\r\n            config: PretrainedConfig = None,\r\n            pooler_type: str = None,\r\n            proj: str = None,\r\n            pretrained: bool = True,\r\n            masked_language_modeling: bool = False):\r\n        super().__init__()\r\n\r\n        self.output_dim = output_dim\r\n\r\n        # TODO: find better way to get this information\r\n        uses_transformer_pooler = (pooler_type == \"cls_pooler\")\r\n\r\n        if transformers is None:\r\n            raise RuntimeError(\"Please `pip install transformers` to use pre-trained HuggingFace models\")\r\n        if config is None:\r\n            self.config = AutoConfig.from_pretrained(model_name_or_path)\r\n            if masked_language_modeling:\r\n                create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (\r\n        
            AutoModelForMaskedLM.from_config, self.config)\r\n            else:\r\n                create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (\r\n                    AutoModel.from_config, self.config)\r\n            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??\r\n            if hasattr(self.config, \"is_encoder_decoder\") and self.config.is_encoder_decoder:\r\n                self.transformer = create_func(model_args)\r\n                self.transformer = self.transformer.encoder\r\n            else:\r\n                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)\r\n        else:\r\n            self.config = config\r\n            if masked_language_modeling:\r\n                self.transformer = AutoModelForMaskedLM.from_config(config)\r\n            else:\r\n                self.transformer = AutoModel.from_config(config)\r\n\r\n        if pooler_type is None: # get default arch pooler\r\n            self.pooler = _POOLERS[(arch_dict[self.config.model_type][\"pooler\"])]()\r\n        else:\r\n            self.pooler = _POOLERS[pooler_type]()\r\n\r\n        d_model = getattr(self.config, arch_dict[self.config.model_type][\"config_names\"][\"width\"])\r\n        if (d_model == output_dim) and (proj is None): # do we always need a proj?\r\n            self.proj = nn.Identity()\r\n        elif proj == 'linear':\r\n            self.proj = nn.Linear(d_model, output_dim, bias=False)\r\n        elif proj == 'mlp':\r\n            hidden_size = (d_model + output_dim) // 2\r\n            self.proj = nn.Sequential(\r\n                nn.Linear(d_model, hidden_size, bias=False),\r\n                nn.GELU(),\r\n                nn.Linear(hidden_size, output_dim, bias=False),\r\n            )\r\n\r\n        # self.itm_proj = nn.Linear(d_model, 2, bias=False)\r\n        # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)\r\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\r\n\r\n    # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:\r\n    #     image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)  \r\n    #     attn_mask = (x != self.config.pad_token_id).long()\r\n    #     out = self.transformer(\r\n    #         input_ids=x, \r\n    #         attention_mask=attn_mask,\r\n    #         encoder_hidden_states = image_embeds,\r\n    #         encoder_attention_mask = image_atts,\r\n    #         )\r\n    #     pooled_out = self.pooler(out, attn_mask)\r\n\r\n    #     return self.itm_proj(pooled_out)\r\n\r\n    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):\r\n        if masked_indices is None:                                       \r\n            masked_indices = torch.bernoulli(probability_matrix).bool()\r\n                                               \r\n        masked_indices[input_ids == self.tokenizer.pad_token_id] = False\r\n        masked_indices[input_ids == self.tokenizer.cls_token_id] = False\r\n        \r\n        if targets is not None:\r\n            targets[~masked_indices] = -100 # We only compute loss on masked tokens            \r\n\r\n        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\r\n        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices\r\n        input_ids[indices_replaced] = self.tokenizer.mask_token_id\r\n\r\n 
       # 10% of the time, we replace masked input tokens with random word\r\n        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced\r\n        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)\r\n        input_ids[indices_random] = random_words[indices_random]                     \r\n        # The rest of the time (10% of the time) we keep the masked input tokens unchanged   \r\n        \r\n        if targets is not None:\r\n            return input_ids, targets\r\n        else:\r\n            return input_ids\r\n\r\n    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):\r\n        labels = input_ids.clone()\r\n        attn_mask = (input_ids != self.config.pad_token_id).long()\r\n        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) \r\n        vocab_size = getattr(self.config, arch_dict[self.config.model_type][\"config_names\"][\"vocab_size\"])\r\n        probability_matrix = torch.full(labels.shape, mlm_probability)\r\n        input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,\r\n                                      probability_matrix = probability_matrix)\r\n        mlm_output = self.transformer(input_ids,\r\n                        attention_mask = attn_mask,\r\n                        encoder_hidden_states = image_embeds,\r\n                        encoder_attention_mask = image_atts,\r\n                        return_dict = True,\r\n                        labels = labels,\r\n                    )\r\n        return mlm_output.loss\r\n        # mlm_output = self.transformer(input_ids,\r\n        #                 attention_mask = attn_mask,\r\n        #                 encoder_hidden_states = image_embeds,\r\n        #                 encoder_attention_mask = image_atts,\r\n        #                 return_dict = True,\r\n        #             ).last_hidden_state\r\n        # logits = self.mlm_proj(mlm_output)\r\n\r\n        # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)\r\n        # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)\r\n        # labels = labels[:, 1:].contiguous().view(-1)\r\n\r\n        # mlm_loss = F.cross_entropy(\r\n        #     logits,\r\n        #     labels,\r\n        #     # label_smoothing=0.1,\r\n        # )\r\n        # return mlm_loss\r\n\r\n\r\n    def forward(self, x:TensorType) -> TensorType:\r\n        attn_mask = (x != self.config.pad_token_id).long()\r\n        out = self.transformer(input_ids=x, attention_mask=attn_mask)\r\n        pooled_out = self.pooler(out, attn_mask)\r\n\r\n        return self.proj(pooled_out)\r\n\r\n    def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):\r\n        if not unlocked_layers: # full freezing\r\n             for n, p in self.transformer.named_parameters():\r\n                 p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\r\n             return\r\n\r\n        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer\r\n        layer_list = getattr(encoder, arch_dict[self.config.model_type][\"config_names\"][\"layer_attr\"])\r\n        print(f\"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model\")\r\n        embeddings = getattr(\r\n            self.transformer, arch_dict[self.config.model_type][\"config_names\"][\"token_embeddings_attr\"])\r\n        modules = [embeddings, 
*layer_list][:-unlocked_layers]\r\n        # freeze layers\r\n        for module in modules:\r\n            for n, p in module.named_parameters():\r\n                p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\r\n\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.transformer.gradient_checkpointing_enable()\r\n\r\n    def get_num_layers(self):\r\n        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer\r\n        layer_list = getattr(encoder, arch_dict[self.config.model_type][\"config_names\"][\"layer_attr\"])\r\n        return len(layer_list)\r\n\r\n    def init_parameters(self):\r\n        pass\r\n"
  },
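  {
    "path": "examples/hf_pooler_sketch.py",
    "content": "# Sketch of the pooler registry in eva_clip/hf_model.py: poolers are registered under\r\n# the snake_case form of their class name, and each reduces a HuggingFace\r\n# last_hidden_state to one vector per sequence using the attention mask.\r\n# Hypothetical example file; assumes torch and transformers are installed.\r\nimport torch\r\nfrom transformers.modeling_outputs import BaseModelOutput\r\n\r\nfrom eva_clip.hf_model import _POOLERS, MeanPooler\r\n\r\nprint(sorted(_POOLERS))  # expected to include 'cls_pooler', 'max_pooler', 'mean_pooler'\r\n\r\nhidden = torch.randn(2, 5, 8)           # (batch, seq_len, width)\r\nmask = torch.tensor([[1, 1, 1, 0, 0],\r\n                     [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding\r\nout = BaseModelOutput(last_hidden_state=hidden)\r\n\r\npooled = MeanPooler()(out, mask)        # mean over unmasked positions only\r\nassert pooled.shape == (2, 8)\r\n"
  },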
  {
    "path": "eva_clip/loss.py",
    "content": "import math\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.nn import functional as F\r\n\r\ntry:\r\n    import torch.distributed.nn\r\n    from torch import distributed as dist\r\n    has_distributed = True\r\nexcept ImportError:\r\n    has_distributed = False\r\n\r\ntry:\r\n    import horovod.torch as hvd\r\nexcept ImportError:\r\n    hvd = None\r\n\r\nfrom timm.loss import LabelSmoothingCrossEntropy\r\n\r\n\r\ndef gather_features(\r\n        image_features,\r\n        text_features,\r\n        local_loss=False,\r\n        gather_with_grad=False,\r\n        rank=0,\r\n        world_size=1,\r\n        use_horovod=False\r\n):\r\n    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'\r\n    if use_horovod:\r\n        assert hvd is not None, 'Please install horovod'\r\n        if gather_with_grad:\r\n            all_image_features = hvd.allgather(image_features)\r\n            all_text_features = hvd.allgather(text_features)\r\n        else:\r\n            with torch.no_grad():\r\n                all_image_features = hvd.allgather(image_features)\r\n                all_text_features = hvd.allgather(text_features)\r\n            if not local_loss:\r\n                # ensure grads for local rank when all_* features don't have a gradient\r\n                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))\r\n                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))\r\n                gathered_image_features[rank] = image_features\r\n                gathered_text_features[rank] = text_features\r\n                all_image_features = torch.cat(gathered_image_features, dim=0)\r\n                all_text_features = torch.cat(gathered_text_features, dim=0)\r\n    else:\r\n        # We gather tensors from all gpus\r\n        if gather_with_grad:\r\n            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)\r\n            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)\r\n            # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)\r\n            # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)\r\n        else:\r\n            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]\r\n            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]\r\n            dist.all_gather(gathered_image_features, image_features)\r\n            dist.all_gather(gathered_text_features, text_features)\r\n            if not local_loss:\r\n                # ensure grads for local rank when all_* features don't have a gradient\r\n                gathered_image_features[rank] = image_features\r\n                gathered_text_features[rank] = text_features\r\n            all_image_features = torch.cat(gathered_image_features, dim=0)\r\n            all_text_features = torch.cat(gathered_text_features, dim=0)\r\n\r\n    return all_image_features, all_text_features\r\n\r\n\r\nclass ClipLoss(nn.Module):\r\n\r\n    def __init__(\r\n            self,\r\n            local_loss=False,\r\n            gather_with_grad=False,\r\n            cache_labels=False,\r\n            rank=0,\r\n            world_size=1,\r\n            use_horovod=False,\r\n            smoothing=0.,\r\n    ):\r\n        super().__init__()\r\n        self.local_loss = 
local_loss\r\n        self.gather_with_grad = gather_with_grad\r\n        self.cache_labels = cache_labels\r\n        self.rank = rank\r\n        self.world_size = world_size\r\n        self.use_horovod = use_horovod\r\n        self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None\r\n\r\n        # cache state\r\n        self.prev_num_logits = 0\r\n        self.labels = {}\r\n\r\n    def forward(self, image_features, text_features, logit_scale=1.):\r\n        device = image_features.device\r\n        if self.world_size > 1:\r\n            all_image_features, all_text_features = gather_features(\r\n                image_features, text_features,\r\n                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)\r\n\r\n            if self.local_loss:\r\n                logits_per_image = logit_scale * image_features @ all_text_features.T\r\n                logits_per_text = logit_scale * text_features @ all_image_features.T\r\n            else:\r\n                logits_per_image = logit_scale * all_image_features @ all_text_features.T\r\n                logits_per_text = logits_per_image.T\r\n        else:\r\n            logits_per_image = logit_scale * image_features @ text_features.T\r\n            logits_per_text = logit_scale * text_features @ image_features.T\r\n        # calculated ground-truth and cache if enabled\r\n        num_logits = logits_per_image.shape[0]\r\n        if self.prev_num_logits != num_logits or device not in self.labels:\r\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\r\n            if self.world_size > 1 and self.local_loss:\r\n                labels = labels + num_logits * self.rank\r\n            if self.cache_labels:\r\n                self.labels[device] = labels\r\n                self.prev_num_logits = num_logits\r\n        else:\r\n            labels = self.labels[device]\r\n        \r\n        if self.label_smoothing_cross_entropy:\r\n            total_loss = (\r\n                self.label_smoothing_cross_entropy(logits_per_image, labels) +\r\n                self.label_smoothing_cross_entropy(logits_per_text, labels)\r\n                ) / 2\r\n        else:\r\n            total_loss = (\r\n                F.cross_entropy(logits_per_image, labels) +\r\n                F.cross_entropy(logits_per_text, labels)\r\n                ) / 2\r\n            \r\n        acc = None\r\n        i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)\r\n        t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)\r\n        acc = {\"i2t\": i2t_acc, \"t2i\": t2i_acc}\r\n        return total_loss, acc"
  },
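  {
    "path": "examples/clip_loss_sketch.py",
    "content": "# Minimal single-process sketch of ClipLoss from eva_clip/loss.py: with world_size=1\r\n# it builds the (batch x batch) image->text and text->image similarity matrices and\r\n# averages the two cross-entropies against the diagonal labels, also returning\r\n# retrieval accuracies. Hypothetical example file; assumes torch and timm are installed.\r\nimport torch\r\nimport torch.nn.functional as F\r\n\r\nfrom eva_clip.loss import ClipLoss\r\n\r\nloss_fn = ClipLoss()  # defaults: local_loss=False, world_size=1, smoothing=0.\r\n\r\nimage_features = F.normalize(torch.randn(4, 512), dim=-1)\r\ntext_features = F.normalize(torch.randn(4, 512), dim=-1)\r\nlogit_scale = torch.tensor(100.0)  # models pass logit_scale.exp(); init exp(log(1/0.07)) ~= 14.3, often ~100 after training\r\n\r\ntotal_loss, acc = loss_fn(image_features, text_features, logit_scale)\r\nprint(total_loss.item(), acc['i2t'].item(), acc['t2i'].item())\r\n"
  },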
  {
    "path": "eva_clip/model.py",
    "content": "\"\"\" CLIP Model\r\n\r\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\r\n\"\"\"\r\nimport os\r\nfrom dataclasses import dataclass\r\nfrom typing import Optional, Tuple, Union\r\nfrom functools import partial\r\n\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch import nn\r\n\r\ntry:\r\n    from .hf_model import HFTextEncoder\r\nexcept:\r\n    HFTextEncoder = None\r\nfrom .modified_resnet import ModifiedResNet\r\nfrom .timm_model import TimmModel\r\nfrom .eva_vit_model import EVAVisionTransformer\r\nfrom .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer\r\n\r\ntry:\r\n    from apex.normalization import FusedLayerNorm\r\nexcept:\r\n    FusedLayerNorm = LayerNorm\r\n    print(\"Nvidia APEX normalization not installed, using PyTorch LayerNorm\")\r\n\r\ntry:\r\n    import xformers.ops as xops\r\nexcept ImportError:\r\n    xops = None\r\n    #print(\"Please 'pip install xformers'\")\r\n\r\n@dataclass\r\nclass CLIPVisionCfg:\r\n    layers: Union[Tuple[int, int, int, int], int] = 12\r\n    width: int = 768\r\n    head_width: int = 64\r\n    mlp_ratio: float = 4.0\r\n    patch_size: int = 16\r\n    image_size: Union[Tuple[int, int], int] = 224\r\n    ls_init_value: Optional[float] = None  # layer scale initial value\r\n    patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\r\n    global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)\r\n    drop_path_rate: Optional[float] = None  # drop path rate\r\n    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size\r\n    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model\r\n    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\r\n    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')\r\n    timm_proj_bias: bool = False  # enable bias final projection\r\n    eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size\r\n    qkv_bias: bool = True\r\n    fusedLN: bool = False\r\n    xattn: bool = False\r\n    postnorm: bool = False\r\n    rope: bool = False\r\n    pt_hw_seq_len: int = 16   # 224/14\r\n    intp_freq: bool = False\r\n    naiveswiglu: bool = False\r\n    subln: bool = False\r\n\r\n\r\n@dataclass\r\nclass CLIPTextCfg:\r\n    context_length: int = 77\r\n    vocab_size: int = 49408\r\n    width: int = 512\r\n    heads: int = 8\r\n    layers: int = 12\r\n    ls_init_value: Optional[float] = None  # layer scale initial value\r\n    hf_model_name: str = None\r\n    hf_tokenizer_name: str = None\r\n    hf_model_pretrained: bool = True\r\n    proj: str = 'mlp'\r\n    pooler_type: str = 'mean_pooler'\r\n    masked_language_modeling: bool = False\r\n    fusedLN: bool = False\r\n    xattn: bool = False\r\n    attn_mask: bool = True\r\n\r\ndef get_cast_dtype(precision: str):\r\n    cast_dtype = None\r\n    if precision == 'bf16':\r\n        cast_dtype = torch.bfloat16\r\n    elif precision == 'fp16':\r\n        cast_dtype = torch.float16\r\n    return cast_dtype\r\n\r\n\r\ndef _build_vision_tower(\r\n        embed_dim: int,\r\n        vision_cfg: CLIPVisionCfg,\r\n        quick_gelu: bool = False,\r\n        
cast_dtype: Optional[torch.dtype] = None\r\n):\r\n    if isinstance(vision_cfg, dict):\r\n        vision_cfg = CLIPVisionCfg(**vision_cfg)\r\n\r\n    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\r\n    # memory efficient in recent PyTorch releases (>= 1.10).\r\n    # NOTE: timm models always use native GELU regardless of quick_gelu flag.\r\n    act_layer = QuickGELU if quick_gelu else nn.GELU\r\n\r\n    if vision_cfg.eva_model_name:\r\n        vision_heads = vision_cfg.width // vision_cfg.head_width\r\n        norm_layer = LayerNorm\r\n        \r\n        visual = EVAVisionTransformer(\r\n            img_size=vision_cfg.image_size,\r\n            patch_size=vision_cfg.patch_size,\r\n            num_classes=embed_dim,\r\n            use_mean_pooling=vision_cfg.global_average_pool, #False\r\n            init_values=vision_cfg.ls_init_value,\r\n            patch_dropout=vision_cfg.patch_dropout,\r\n            embed_dim=vision_cfg.width,\r\n            depth=vision_cfg.layers,\r\n            num_heads=vision_heads,\r\n            mlp_ratio=vision_cfg.mlp_ratio,\r\n            qkv_bias=vision_cfg.qkv_bias,\r\n            drop_path_rate=vision_cfg.drop_path_rate,\r\n            norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),\r\n            xattn=vision_cfg.xattn,\r\n            rope=vision_cfg.rope,\r\n            postnorm=vision_cfg.postnorm,\r\n            pt_hw_seq_len= vision_cfg.pt_hw_seq_len,   # 224/14\r\n            intp_freq= vision_cfg.intp_freq,\r\n            naiveswiglu= vision_cfg.naiveswiglu,\r\n            subln= vision_cfg.subln\r\n        )\r\n    elif vision_cfg.timm_model_name:\r\n        visual = TimmModel(\r\n            vision_cfg.timm_model_name,\r\n            pretrained=vision_cfg.timm_model_pretrained,\r\n            pool=vision_cfg.timm_pool,\r\n            proj=vision_cfg.timm_proj,\r\n            proj_bias=vision_cfg.timm_proj_bias,\r\n            embed_dim=embed_dim,\r\n            image_size=vision_cfg.image_size\r\n        )\r\n        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models\r\n    elif isinstance(vision_cfg.layers, (tuple, list)):\r\n        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width\r\n        visual = ModifiedResNet(\r\n            layers=vision_cfg.layers,\r\n            output_dim=embed_dim,\r\n            heads=vision_heads,\r\n            image_size=vision_cfg.image_size,\r\n            width=vision_cfg.width\r\n        )\r\n    else:\r\n        vision_heads = vision_cfg.width // vision_cfg.head_width\r\n        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\r\n        visual = VisionTransformer(\r\n            image_size=vision_cfg.image_size,\r\n            patch_size=vision_cfg.patch_size,\r\n            width=vision_cfg.width,\r\n            layers=vision_cfg.layers,\r\n            heads=vision_heads,\r\n            mlp_ratio=vision_cfg.mlp_ratio,\r\n            ls_init_value=vision_cfg.ls_init_value,\r\n            patch_dropout=vision_cfg.patch_dropout,\r\n            global_average_pool=vision_cfg.global_average_pool,\r\n            output_dim=embed_dim,\r\n            act_layer=act_layer,\r\n            norm_layer=norm_layer,\r\n        )\r\n\r\n    return visual\r\n\r\n\r\ndef _build_text_tower(\r\n        embed_dim: int,\r\n        text_cfg: CLIPTextCfg,\r\n        quick_gelu: bool = False,\r\n        cast_dtype: Optional[torch.dtype] = 
None,\r\n):\r\n    if isinstance(text_cfg, dict):\r\n        text_cfg = CLIPTextCfg(**text_cfg)\r\n\r\n    if text_cfg.hf_model_name:\r\n        text = HFTextEncoder(\r\n            text_cfg.hf_model_name,\r\n            output_dim=embed_dim,\r\n            tokenizer_name=text_cfg.hf_tokenizer_name,\r\n            proj=text_cfg.proj,\r\n            pooler_type=text_cfg.pooler_type,\r\n            masked_language_modeling=text_cfg.masked_language_modeling\r\n       )\r\n    else:\r\n        act_layer = QuickGELU if quick_gelu else nn.GELU\r\n        norm_layer = LayerNorm\r\n\r\n        text = TextTransformer(\r\n            context_length=text_cfg.context_length,\r\n            vocab_size=text_cfg.vocab_size,\r\n            width=text_cfg.width,\r\n            heads=text_cfg.heads,\r\n            layers=text_cfg.layers,\r\n            ls_init_value=text_cfg.ls_init_value,\r\n            output_dim=embed_dim,\r\n            act_layer=act_layer,\r\n            norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,\r\n            xattn=text_cfg.xattn,\r\n            attn_mask=text_cfg.attn_mask,\r\n        )\r\n    return text\r\n\r\nclass CLIP(nn.Module):\r\n    def __init__(\r\n            self,\r\n            embed_dim: int,\r\n            vision_cfg: CLIPVisionCfg,\r\n            text_cfg: CLIPTextCfg,\r\n            quick_gelu: bool = False,\r\n            cast_dtype: Optional[torch.dtype] = None,\r\n    ):\r\n        super().__init__()\r\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\r\n\r\n        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\r\n        self.transformer = text.transformer\r\n        self.vocab_size = text.vocab_size\r\n        self.token_embedding = text.token_embedding\r\n        self.positional_embedding = text.positional_embedding\r\n        self.ln_final = text.ln_final\r\n        self.text_projection = text.text_projection\r\n        self.register_buffer('attn_mask', text.attn_mask, persistent=False)\r\n\r\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\r\n\r\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\r\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.visual.set_grad_checkpointing(enable)\r\n        self.transformer.grad_checkpointing = enable\r\n    \r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        return {'logit_scale'}\r\n\r\n    def encode_image(self, image, normalize: bool = False):\r\n        features = self.visual(image)\r\n        return F.normalize(features, dim=-1) if normalize else features\r\n\r\n    def encode_text(self, text, normalize: bool = False):\r\n        cast_dtype = self.transformer.get_cast_dtype()\r\n\r\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\r\n\r\n        x = x + self.positional_embedding.to(cast_dtype)\r\n        x = x.permute(1, 0, 2)  # NLD -> LND\r\n        x = self.transformer(x, attn_mask=self.attn_mask)\r\n        x = x.permute(1, 0, 2)  # LND -> NLD\r\n        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]\r\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\r\n        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\r\n       
 return F.normalize(x, dim=-1) if normalize else x\r\n\r\n    def forward(self, image, text):\r\n        image_features = self.encode_image(image, normalize=True)\r\n        text_features = self.encode_text(text, normalize=True)\r\n        return image_features, text_features, self.logit_scale.exp()\r\n\r\n\r\nclass CustomCLIP(nn.Module):\r\n    def __init__(\r\n            self,\r\n            embed_dim: int,\r\n            vision_cfg: CLIPVisionCfg,\r\n            text_cfg: CLIPTextCfg,\r\n            quick_gelu: bool = False,\r\n            cast_dtype: Optional[torch.dtype] = None,\r\n            itm_task: bool = False,\r\n    ):\r\n        super().__init__()\r\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\r\n        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\r\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\r\n\r\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\r\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\r\n\r\n    def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):\r\n        self.text.lock(unlocked_layers, freeze_layer_norm)\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.visual.set_grad_checkpointing(enable)\r\n        self.text.set_grad_checkpointing(enable)\r\n\r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        return {'logit_scale'}\r\n\r\n    def encode_image(self, image, normalize: bool = False):\r\n        features = self.visual(image)\r\n        return F.normalize(features, dim=-1) if normalize else features\r\n\r\n    def encode_text(self, text, normalize: bool = False):\r\n        features = self.text(text)\r\n        return F.normalize(features, dim=-1) if normalize else features\r\n\r\n    def forward(self, image, text):\r\n        image_features = self.encode_image(image, normalize=True)\r\n        text_features = self.encode_text(text, normalize=True)\r\n        return image_features, text_features, self.logit_scale.exp()\r\n\r\n\r\ndef convert_weights_to_lp(model: nn.Module, dtype=torch.float16):\r\n    \"\"\"Convert applicable model parameters to low-precision (bf16 or fp16)\"\"\"\r\n\r\n    def _convert_weights(l):\r\n        \r\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\r\n            l.weight.data = l.weight.data.to(dtype)\r\n            if l.bias is not None:\r\n                l.bias.data = l.bias.data.to(dtype)\r\n\r\n        if isinstance(l, (nn.MultiheadAttention, Attention)):\r\n            for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\r\n                tensor = getattr(l, attr, None)\r\n                if tensor is not None:\r\n                    tensor.data = tensor.data.to(dtype)\r\n\r\n        if isinstance(l, nn.Parameter):\r\n            l.data = l.data.to(dtype)\r\n\r\n        for name in [\"text_projection\", \"proj\"]:\r\n            if hasattr(l, name) and isinstance(l, nn.Parameter):\r\n                attr = getattr(l, name, None)\r\n                if attr is not None:\r\n                    attr.data = attr.data.to(dtype)\r\n\r\n    model.apply(_convert_weights)\r\n\r\n\r\nconvert_weights_to_fp16 = convert_weights_to_lp  # backwards compat\r\n\r\n\r\n# used to maintain checkpoint compatibility\r\ndef 
convert_to_custom_text_state_dict(state_dict: dict):\r\n    if 'text_projection' in state_dict:\r\n        # old format state_dict, move text tower -> .text\r\n        new_state_dict = {}\r\n        for k, v in state_dict.items():\r\n            if any(k.startswith(p) for p in (\r\n                'text_projection',\r\n                'positional_embedding',\r\n                'token_embedding',\r\n                'transformer',\r\n                'ln_final',\r\n                'logit_scale'\r\n            )):\r\n                k = 'text.' + k\r\n            new_state_dict[k] = v\r\n        return new_state_dict\r\n    return state_dict\r\n\r\n\r\ndef build_model_from_openai_state_dict(\r\n        state_dict: dict,\r\n        quick_gelu=True,\r\n        cast_dtype=torch.float16,\r\n):\r\n    vit = \"visual.proj\" in state_dict\r\n\r\n    if vit:\r\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\r\n        vision_layers = len(\r\n            [k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\r\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\r\n        grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\r\n        image_size = vision_patch_size * grid_size\r\n    else:\r\n        counts: list = [\r\n            len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\r\n        vision_layers = tuple(counts)\r\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\r\n        output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\r\n        vision_patch_size = None\r\n        assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\r\n        image_size = output_width * 32\r\n\r\n    embed_dim = state_dict[\"text_projection\"].shape[1]\r\n    context_length = state_dict[\"positional_embedding\"].shape[0]\r\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\r\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\r\n    transformer_heads = transformer_width // 64\r\n    transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\r\n\r\n    vision_cfg = CLIPVisionCfg(\r\n        layers=vision_layers,\r\n        width=vision_width,\r\n        patch_size=vision_patch_size,\r\n        image_size=image_size,\r\n    )\r\n    text_cfg = CLIPTextCfg(\r\n        context_length=context_length,\r\n        vocab_size=vocab_size,\r\n        width=transformer_width,\r\n        heads=transformer_heads,\r\n        layers=transformer_layers\r\n    )\r\n    model = CLIP(\r\n        embed_dim,\r\n        vision_cfg=vision_cfg,\r\n        text_cfg=text_cfg,\r\n        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU\r\n        cast_dtype=cast_dtype,\r\n    )\r\n\r\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\r\n        state_dict.pop(key, None)\r\n\r\n    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16\r\n    model.load_state_dict(state_dict)\r\n    return model.eval()\r\n\r\n\r\ndef trace_model(model, batch_size=256, device=torch.device('cpu')):\r\n    model.eval()\r\n    image_size = model.visual.image_size\r\n    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\r\n    example_text = torch.zeros((batch_size, model.context_length), 
dtype=torch.int, device=device)\r\n    model = torch.jit.trace_module(\r\n        model,\r\n        inputs=dict(\r\n            forward=(example_images, example_text),\r\n            encode_text=(example_text,),\r\n            encode_image=(example_images,)\r\n        ))\r\n    model.visual.image_size = image_size\r\n    return model\r\n"
  },
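  {
    "path": "examples/clip_model_sketch.py",
    "content": "# Sketch of the interface shared by CLIP and CustomCLIP in eva_clip/model.py:\r\n# encode_image / encode_text return embed_dim-sized embeddings and forward() returns\r\n# the normalized image and text features plus logit_scale.exp(), which a caller such\r\n# as ClipLoss turns into similarity logits. Hypothetical example file; the small\r\n# ViT-style config below is illustrative only.\r\nimport torch\r\n\r\nfrom eva_clip.model import CLIP, CLIPTextCfg, CLIPVisionCfg\r\n\r\nmodel = CLIP(\r\n    embed_dim=512,\r\n    vision_cfg=CLIPVisionCfg(layers=12, width=768, patch_size=16, image_size=224),\r\n    text_cfg=CLIPTextCfg(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),\r\n)\r\nmodel.eval()\r\n\r\nimage = torch.randn(2, 3, 224, 224)\r\ntext = torch.randint(0, 49408, (2, 77))\r\n\r\nwith torch.no_grad():\r\n    image_features, text_features, logit_scale = model(image, text)\r\n\r\nassert image_features.shape == text_features.shape == (2, 512)\r\nlogits_per_image = logit_scale * image_features @ text_features.T  # (2, 2)\r\n"
  },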
  {
    "path": "eva_clip/model_configs/EVA01-CLIP-B-16.json",
    "content": "{\r\n    \"embed_dim\": 512,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 12,\r\n        \"width\": 768,\r\n        \"patch_size\": 16,\r\n        \"eva_model_name\": \"eva-clip-b-16\",\r\n        \"ls_init_value\": 0.1,\r\n        \"drop_path_rate\": 0.0\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 512,\r\n        \"heads\": 8,\r\n        \"layers\": 12\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA01-CLIP-g-14-plus.json",
    "content": "{\r\n    \"embed_dim\": 1024,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 40,\r\n        \"width\": 1408,\r\n        \"head_width\": 88,\r\n        \"mlp_ratio\": 4.3637,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-g-14-x\",\r\n        \"drop_path_rate\": 0,\r\n        \"xattn\": true,\r\n        \"fusedLN\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 1024,\r\n        \"heads\": 16,\r\n        \"layers\": 24,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA01-CLIP-g-14.json",
    "content": "{\r\n    \"embed_dim\": 1024,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 40,\r\n        \"width\": 1408,\r\n        \"head_width\": 88,\r\n        \"mlp_ratio\": 4.3637,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-g-14-x\",\r\n        \"drop_path_rate\": 0.4,\r\n        \"xattn\": true,\r\n        \"fusedLN\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 768,\r\n        \"heads\": 12,\r\n        \"layers\": 12,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA02-CLIP-B-16.json",
    "content": "{\r\n    \"embed_dim\": 512,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 12,\r\n        \"width\": 768,\r\n        \"head_width\": 64,\r\n        \"patch_size\": 16,\r\n        \"mlp_ratio\": 2.6667,\r\n        \"eva_model_name\": \"eva-clip-b-16-X\",\r\n        \"drop_path_rate\": 0.0,\r\n        \"xattn\": true,\r\n        \"fusedLN\": true,\r\n        \"rope\": true,\r\n        \"pt_hw_seq_len\": 16,\r\n        \"intp_freq\": true,\r\n        \"naiveswiglu\": true,\r\n        \"subln\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 512,\r\n        \"heads\": 8,\r\n        \"layers\": 12,\r\n        \"xattn\": true,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA02-CLIP-L-14-336.json",
    "content": "{\r\n    \"embed_dim\": 768,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 336,\r\n        \"layers\": 24,\r\n        \"width\": 1024,\r\n        \"drop_path_rate\": 0,\r\n        \"head_width\": 64,\r\n        \"mlp_ratio\": 2.6667,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-l-14-336\",\r\n        \"xattn\": true,\r\n        \"fusedLN\": true,\r\n        \"rope\": true,\r\n        \"pt_hw_seq_len\": 16,\r\n        \"intp_freq\": true,\r\n        \"naiveswiglu\": true,\r\n        \"subln\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 768,\r\n        \"heads\": 12,\r\n        \"layers\": 12,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA02-CLIP-L-14.json",
    "content": "{\r\n    \"embed_dim\": 768,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 24,\r\n        \"width\": 1024,\r\n        \"drop_path_rate\": 0,\r\n        \"head_width\": 64,\r\n        \"mlp_ratio\": 2.6667,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-l-14\",\r\n        \"xattn\": true,\r\n        \"fusedLN\": true,\r\n        \"rope\": true,\r\n        \"pt_hw_seq_len\": 16,\r\n        \"intp_freq\": true,\r\n        \"naiveswiglu\": true,\r\n        \"subln\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 768,\r\n        \"heads\": 12,\r\n        \"layers\": 12,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json",
    "content": "{\r\n    \"embed_dim\": 1024,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 64,\r\n        \"width\": 1792,\r\n        \"head_width\": 112,\r\n        \"mlp_ratio\": 8.571428571428571,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-4b-14-x\",\r\n        \"drop_path_rate\": 0,\r\n        \"xattn\": true,\r\n        \"postnorm\": true,\r\n        \"fusedLN\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 1280,\r\n        \"heads\": 20,\r\n        \"layers\": 32,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}\r\n"
  },
  {
    "path": "eva_clip/model_configs/EVA02-CLIP-bigE-14.json",
    "content": "{\r\n    \"embed_dim\": 1024,\r\n    \"vision_cfg\": {\r\n        \"image_size\": 224,\r\n        \"layers\": 64,\r\n        \"width\": 1792,\r\n        \"head_width\": 112,\r\n        \"mlp_ratio\": 8.571428571428571,\r\n        \"patch_size\": 14,\r\n        \"eva_model_name\": \"eva-clip-4b-14-x\",\r\n        \"drop_path_rate\": 0,\r\n        \"xattn\": true,\r\n        \"postnorm\": true,\r\n        \"fusedLN\": true\r\n    },\r\n    \"text_cfg\": {\r\n        \"context_length\": 77,\r\n        \"vocab_size\": 49408,\r\n        \"width\": 1024,\r\n        \"heads\": 16,\r\n        \"layers\": 24,\r\n        \"xattn\": false,\r\n        \"fusedLN\": true\r\n    }\r\n}"
  },
  {
    "path": "eva_clip/modified_resnet.py",
    "content": "from collections import OrderedDict\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\nfrom .utils import freeze_batch_norm_2d\r\n\r\n\r\nclass Bottleneck(nn.Module):\r\n    expansion = 4\r\n\r\n    def __init__(self, inplanes, planes, stride=1):\r\n        super().__init__()\r\n\r\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\r\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(planes)\r\n        self.act1 = nn.ReLU(inplace=True)\r\n\r\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\r\n        self.bn2 = nn.BatchNorm2d(planes)\r\n        self.act2 = nn.ReLU(inplace=True)\r\n\r\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\r\n\r\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\r\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\r\n        self.act3 = nn.ReLU(inplace=True)\r\n\r\n        self.downsample = None\r\n        self.stride = stride\r\n\r\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\r\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\r\n            self.downsample = nn.Sequential(OrderedDict([\r\n                (\"-1\", nn.AvgPool2d(stride)),\r\n                (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\r\n                (\"1\", nn.BatchNorm2d(planes * self.expansion))\r\n            ]))\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        identity = x\r\n\r\n        out = self.act1(self.bn1(self.conv1(x)))\r\n        out = self.act2(self.bn2(self.conv2(out)))\r\n        out = self.avgpool(out)\r\n        out = self.bn3(self.conv3(out))\r\n\r\n        if self.downsample is not None:\r\n            identity = self.downsample(x)\r\n\r\n        out += identity\r\n        out = self.act3(out)\r\n        return out\r\n\r\n\r\nclass AttentionPool2d(nn.Module):\r\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\r\n        super().__init__()\r\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\r\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\r\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\r\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\r\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\r\n        self.num_heads = num_heads\r\n\r\n    def forward(self, x):\r\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC\r\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\r\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\r\n        x, _ = F.multi_head_attention_forward(\r\n            query=x, key=x, value=x,\r\n            embed_dim_to_check=x.shape[-1],\r\n            num_heads=self.num_heads,\r\n            q_proj_weight=self.q_proj.weight,\r\n            k_proj_weight=self.k_proj.weight,\r\n            v_proj_weight=self.v_proj.weight,\r\n            in_proj_weight=None,\r\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\r\n            bias_k=None,\r\n            bias_v=None,\r\n            add_zero_attn=False,\r\n            dropout_p=0.,\r\n            out_proj_weight=self.c_proj.weight,\r\n       
     out_proj_bias=self.c_proj.bias,\r\n            use_separate_proj_weight=True,\r\n            training=self.training,\r\n            need_weights=False\r\n        )\r\n\r\n        return x[0]\r\n\r\n\r\nclass ModifiedResNet(nn.Module):\r\n    \"\"\"\r\n    A ResNet class that is similar to torchvision's but contains the following changes:\r\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\r\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\r\n    - The final pooling layer is a QKV attention instead of an average pool\r\n    \"\"\"\r\n\r\n    def __init__(self, layers, output_dim, heads, image_size=224, width=64):\r\n        super().__init__()\r\n        self.output_dim = output_dim\r\n        self.image_size = image_size\r\n\r\n        # the 3-layer stem\r\n        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(width // 2)\r\n        self.act1 = nn.ReLU(inplace=True)\r\n        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\r\n        self.bn2 = nn.BatchNorm2d(width // 2)\r\n        self.act2 = nn.ReLU(inplace=True)\r\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\r\n        self.bn3 = nn.BatchNorm2d(width)\r\n        self.act3 = nn.ReLU(inplace=True)\r\n        self.avgpool = nn.AvgPool2d(2)\r\n\r\n        # residual layers\r\n        self._inplanes = width  # this is a *mutable* variable used during construction\r\n        self.layer1 = self._make_layer(width, layers[0])\r\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\r\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\r\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\r\n\r\n        embed_dim = width * 32  # the ResNet feature dimension\r\n        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\r\n\r\n        self.init_parameters()\r\n\r\n    def _make_layer(self, planes, blocks, stride=1):\r\n        layers = [Bottleneck(self._inplanes, planes, stride)]\r\n\r\n        self._inplanes = planes * Bottleneck.expansion\r\n        for _ in range(1, blocks):\r\n            layers.append(Bottleneck(self._inplanes, planes))\r\n\r\n        return nn.Sequential(*layers)\r\n\r\n    def init_parameters(self):\r\n        if self.attnpool is not None:\r\n            std = self.attnpool.c_proj.in_features ** -0.5\r\n            nn.init.normal_(self.attnpool.q_proj.weight, std=std)\r\n            nn.init.normal_(self.attnpool.k_proj.weight, std=std)\r\n            nn.init.normal_(self.attnpool.v_proj.weight, std=std)\r\n            nn.init.normal_(self.attnpool.c_proj.weight, std=std)\r\n\r\n        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\r\n            for name, param in resnet_block.named_parameters():\r\n                if name.endswith(\"bn3.weight\"):\r\n                    nn.init.zeros_(param)\r\n\r\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        assert unlocked_groups == 0, 'partial locking not currently supported for this model'\r\n        for param in self.parameters():\r\n            param.requires_grad = False\r\n        if freeze_bn_stats:\r\n            freeze_batch_norm_2d(self)\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        # FIXME support for 
non-transformer\r\n        pass\r\n\r\n    def stem(self, x):\r\n        x = self.act1(self.bn1(self.conv1(x)))\r\n        x = self.act2(self.bn2(self.conv2(x)))\r\n        x = self.act3(self.bn3(self.conv3(x)))\r\n        x = self.avgpool(x)\r\n        return x\r\n\r\n    def forward(self, x):\r\n        x = self.stem(x)\r\n        x = self.layer1(x)\r\n        x = self.layer2(x)\r\n        x = self.layer3(x)\r\n        x = self.layer4(x)\r\n        x = self.attnpool(x)\r\n\r\n        return x\r\n"
  },
  {
    "path": "eva_clip/openai.py",
    "content": "\"\"\" OpenAI pretrained model functions\r\n\r\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\r\n\"\"\"\r\n\r\nimport os\r\nimport warnings\r\nfrom typing import List, Optional, Union\r\n\r\nimport torch\r\n\r\nfrom .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype\r\nfrom .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url\r\n\r\n__all__ = [\"list_openai_models\", \"load_openai_model\"]\r\n\r\n\r\ndef list_openai_models() -> List[str]:\r\n    \"\"\"Returns the names of available CLIP models\"\"\"\r\n    return list_pretrained_models_by_tag('openai')\r\n\r\n\r\ndef load_openai_model(\r\n        name: str,\r\n        precision: Optional[str] = None,\r\n        device: Optional[Union[str, torch.device]] = None,\r\n        jit: bool = True,\r\n        cache_dir: Optional[str] = None,\r\n):\r\n    \"\"\"Load a CLIP model\r\n\r\n    Parameters\r\n    ----------\r\n    name : str\r\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\r\n    precision: str\r\n        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.\r\n    device : Union[str, torch.device]\r\n        The device to put the loaded model\r\n    jit : bool\r\n        Whether to load the optimized JIT model (default) or more hackable non-JIT model.\r\n    cache_dir : Optional[str]\r\n        The directory to cache the downloaded model weights\r\n\r\n    Returns\r\n    -------\r\n    model : torch.nn.Module\r\n        The CLIP model\r\n    preprocess : Callable[[PIL.Image], torch.Tensor]\r\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\r\n    \"\"\"\r\n    if device is None:\r\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\r\n    if precision is None:\r\n        precision = 'fp32' if device == 'cpu' else 'fp16'\r\n\r\n    if get_pretrained_url(name, 'openai'):\r\n        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)\r\n    elif os.path.isfile(name):\r\n        model_path = name\r\n    else:\r\n        raise RuntimeError(f\"Model {name} not found; available models = {list_openai_models()}\")\r\n\r\n    try:\r\n        # loading JIT archive\r\n        model = torch.jit.load(model_path, map_location=device if jit else \"cpu\").eval()\r\n        state_dict = None\r\n    except RuntimeError:\r\n        # loading saved state dict\r\n        if jit:\r\n            warnings.warn(f\"File {model_path} is not a JIT archive. 
Loading as a state dict instead\")\r\n            jit = False\r\n        state_dict = torch.load(model_path, map_location=\"cpu\")\r\n\r\n    if not jit:\r\n        # Build a non-jit model from the OpenAI jitted model state dict\r\n        cast_dtype = get_cast_dtype(precision)\r\n        try:\r\n            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)\r\n        except KeyError:\r\n            sd = {k[7:]: v for k, v in state_dict[\"state_dict\"].items()}\r\n            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)\r\n\r\n        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use\r\n        model = model.to(device)\r\n        if precision.startswith('amp') or precision == 'fp32':\r\n            model.float()\r\n        elif precision == 'bf16':\r\n            convert_weights_to_lp(model, dtype=torch.bfloat16)\r\n\r\n        return model\r\n\r\n    # patch the device names\r\n    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])\r\n    device_node = [n for n in device_holder.graph.findAllNodes(\"prim::Constant\") if \"Device\" in repr(n)][-1]\r\n\r\n    def patch_device(module):\r\n        try:\r\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\r\n        except RuntimeError:\r\n            graphs = []\r\n\r\n        if hasattr(module, \"forward1\"):\r\n            graphs.append(module.forward1.graph)\r\n\r\n        for graph in graphs:\r\n            for node in graph.findAllNodes(\"prim::Constant\"):\r\n                if \"value\" in node.attributeNames() and str(node[\"value\"]).startswith(\"cuda\"):\r\n                    node.copyAttributes(device_node)\r\n\r\n    model.apply(patch_device)\r\n    patch_device(model.encode_image)\r\n    patch_device(model.encode_text)\r\n\r\n    # patch dtype to float32 (typically for CPU)\r\n    if precision == 'fp32':\r\n        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])\r\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\r\n        float_node = float_input.node()\r\n\r\n        def patch_float(module):\r\n            try:\r\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\r\n            except RuntimeError:\r\n                graphs = []\r\n\r\n            if hasattr(module, \"forward1\"):\r\n                graphs.append(module.forward1.graph)\r\n\r\n            for graph in graphs:\r\n                for node in graph.findAllNodes(\"aten::to\"):\r\n                    inputs = list(node.inputs())\r\n                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()\r\n                        if inputs[i].node()[\"value\"] == 5:\r\n                            inputs[i].node().copyAttributes(float_node)\r\n\r\n        model.apply(patch_float)\r\n        patch_float(model.encode_image)\r\n        patch_float(model.encode_text)\r\n        model.float()\r\n\r\n    # ensure image_size attr available at consistent location for both jit and non-jit\r\n    model.visual.image_size = model.input_resolution.item()\r\n    return model\r\n"
  },
  {
    "path": "eva_clip/pretrained.py",
    "content": "import hashlib\r\nimport os\r\nimport urllib\r\nimport warnings\r\nfrom functools import partial\r\nfrom typing import Dict, Union\r\n\r\nfrom tqdm import tqdm\r\n\r\ntry:\r\n    from huggingface_hub import hf_hub_download\r\n    _has_hf_hub = True\r\nexcept ImportError:\r\n    hf_hub_download = None\r\n    _has_hf_hub = False\r\n\r\n\r\ndef _pcfg(url='', hf_hub='', filename='', mean=None, std=None):\r\n    return dict(\r\n        url=url,\r\n        hf_hub=hf_hub,\r\n        mean=mean,\r\n        std=std,\r\n    )\r\n\r\n_VITB32 = dict(\r\n    openai=_pcfg(\r\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\r\n    laion400m_e31=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\r\n    laion400m_e32=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\r\n    laion2b_e16=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth\"),\r\n    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')\r\n)\r\n\r\n_VITB32_quickgelu = dict(\r\n    openai=_pcfg(\r\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\r\n    laion400m_e31=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\r\n    laion400m_e32=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\r\n)\r\n\r\n_VITB16 = dict(\r\n    openai=_pcfg(\r\n        \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\"),\r\n    laion400m_e31=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt\"),\r\n    laion400m_e32=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt\"),\r\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),\r\n)\r\n\r\n_EVAB16 = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),\r\n    eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),\r\n    eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),\r\n)\r\n\r\n_VITB16_PLUS_240 = dict(\r\n    laion400m_e31=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt\"),\r\n    laion400m_e32=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt\"),\r\n)\r\n\r\n_VITL14 = dict(\r\n    openai=_pcfg(\r\n        \"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\"),\r\n    laion400m_e31=_pcfg(\r\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt\"),\r\n    laion400m_e32=_pcfg(\r\n        
\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt\"),\r\n    laion2b_s32b_b82k=_pcfg(\r\n        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',\r\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\r\n)\r\n\r\n_EVAL14 = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),\r\n    eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),\r\n    eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),\r\n)\r\n\r\n_VITL14_336 = dict(\r\n    openai=_pcfg(\r\n        \"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\"),\r\n)\r\n\r\n_EVAL14_336 = dict(\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),\r\n    eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),\r\n    eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),\r\n    eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),\r\n)\r\n\r\n_VITH14 = dict(\r\n    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),\r\n)\r\n\r\n_VITg14 = dict(\r\n    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),\r\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),\r\n)\r\n\r\n_EVAg14 = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),\r\n    eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),\r\n    eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),\r\n)\r\n\r\n_EVAg14_PLUS = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),\r\n    eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),\r\n    eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),\r\n)\r\n\r\n_VITbigG14 = dict(\r\n    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),\r\n)\r\n\r\n_EVAbigE14 = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),\r\n    eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),\r\n    eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),\r\n)\r\n\r\n_EVAbigE14_PLUS = dict(\r\n    eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),\r\n    eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),\r\n    eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),\r\n    eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),\r\n)\r\n\r\n\r\n_PRETRAINED = {\r\n    # \"ViT-B-32\": _VITB32,\r\n    \"OpenaiCLIP-B-32\": _VITB32,\r\n    \"OpenCLIP-B-32\": _VITB32,\r\n\r\n    # \"ViT-B-32-quickgelu\": _VITB32_quickgelu,\r\n    \"OpenaiCLIP-B-32-quickgelu\": _VITB32_quickgelu,\r\n    \"OpenCLIP-B-32-quickgelu\": _VITB32_quickgelu,\r\n\r\n    # \"ViT-B-16\": _VITB16,\r\n    \"OpenaiCLIP-B-16\": _VITB16,\r\n    \"OpenCLIP-B-16\": _VITB16,\r\n\r\n    \"EVA02-B-16\": _EVAB16,\r\n    \"EVA02-CLIP-B-16\": _EVAB16,\r\n\r\n    # \"ViT-B-16-plus-240\": _VITB16_PLUS_240,\r\n    \"OpenCLIP-B-16-plus-240\": _VITB16_PLUS_240,\r\n\r\n    # \"ViT-L-14\": _VITL14,\r\n    \"OpenaiCLIP-L-14\": _VITL14,\r\n    \"OpenCLIP-L-14\": _VITL14,\r\n\r\n    \"EVA02-L-14\": _EVAL14,\r\n   
 \"EVA02-CLIP-L-14\": _EVAL14,\r\n\r\n    # \"ViT-L-14-336\": _VITL14_336,\r\n    \"OpenaiCLIP-L-14-336\": _VITL14_336,\r\n\r\n    \"EVA02-CLIP-L-14-336\": _EVAL14_336,\r\n\r\n    # \"ViT-H-14\": _VITH14,\r\n    # \"ViT-g-14\": _VITg14,\r\n    \"OpenCLIP-H-14\": _VITH14,\r\n    \"OpenCLIP-g-14\": _VITg14,\r\n\r\n    \"EVA01-CLIP-g-14\": _EVAg14,\r\n    \"EVA01-CLIP-g-14-plus\": _EVAg14_PLUS,\r\n\r\n    # \"ViT-bigG-14\": _VITbigG14,\r\n    \"OpenCLIP-bigG-14\": _VITbigG14,\r\n\r\n    \"EVA02-CLIP-bigE-14\": _EVAbigE14,\r\n    \"EVA02-CLIP-bigE-14-plus\": _EVAbigE14_PLUS,\r\n}\r\n\r\n\r\ndef _clean_tag(tag: str):\r\n    # normalize pretrained tags\r\n    return tag.lower().replace('-', '_')\r\n\r\n\r\ndef list_pretrained(as_str: bool = False):\r\n    \"\"\" returns list of pretrained models\r\n    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\r\n    \"\"\"\r\n    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]\r\n\r\n\r\ndef list_pretrained_models_by_tag(tag: str):\r\n    \"\"\" return all models having the specified pretrain tag \"\"\"\r\n    models = []\r\n    tag = _clean_tag(tag)\r\n    for k in _PRETRAINED.keys():\r\n        if tag in _PRETRAINED[k]:\r\n            models.append(k)\r\n    return models\r\n\r\n\r\ndef list_pretrained_tags_by_model(model: str):\r\n    \"\"\" return all pretrain tags for the specified model architecture \"\"\"\r\n    tags = []\r\n    if model in _PRETRAINED:\r\n        tags.extend(_PRETRAINED[model].keys())\r\n    return tags\r\n\r\n\r\ndef is_pretrained_cfg(model: str, tag: str):\r\n    if model not in _PRETRAINED:\r\n        return False\r\n    return _clean_tag(tag) in _PRETRAINED[model]\r\n\r\n\r\ndef get_pretrained_cfg(model: str, tag: str):\r\n    if model not in _PRETRAINED:\r\n        return {}\r\n    model_pretrained = _PRETRAINED[model]\r\n    return model_pretrained.get(_clean_tag(tag), {})\r\n\r\n\r\ndef get_pretrained_url(model: str, tag: str):\r\n    cfg = get_pretrained_cfg(model, _clean_tag(tag))\r\n    return cfg.get('url', '')\r\n\r\n\r\ndef download_pretrained_from_url(\r\n        url: str,\r\n        cache_dir: Union[str, None] = None,\r\n):\r\n    if not cache_dir:\r\n        cache_dir = os.path.expanduser(\"~/.cache/clip\")\r\n    os.makedirs(cache_dir, exist_ok=True)\r\n    filename = os.path.basename(url)\r\n\r\n    if 'openaipublic' in url:\r\n        expected_sha256 = url.split(\"/\")[-2]\r\n    elif 'mlfoundations' in url:\r\n        expected_sha256 = os.path.splitext(filename)[0].split(\"-\")[-1]\r\n    else:\r\n        expected_sha256 = ''\r\n\r\n    download_target = os.path.join(cache_dir, filename)\r\n\r\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\r\n        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\r\n\r\n    if os.path.isfile(download_target):\r\n        if expected_sha256:\r\n            if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\r\n                return download_target\r\n            else:\r\n                warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\r\n        else:\r\n            return download_target\r\n\r\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\r\n        with tqdm(total=int(source.headers.get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\r\n            
while True:\r\n                buffer = source.read(8192)\r\n                if not buffer:\r\n                    break\r\n\r\n                output.write(buffer)\r\n                loop.update(len(buffer))\r\n\r\n    if expected_sha256 and not hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\r\n        raise RuntimeError(f\"Model has been downloaded but the SHA256 checksum does not match\")\r\n\r\n    return download_target\r\n\r\n\r\ndef has_hf_hub(necessary=False):\r\n    if not _has_hf_hub and necessary:\r\n        # if no HF Hub module installed, and it is necessary to continue, raise error\r\n        raise RuntimeError(\r\n            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')\r\n    return _has_hf_hub\r\n\r\n\r\ndef download_pretrained_from_hf(\r\n        model_id: str,\r\n        filename: str = 'open_clip_pytorch_model.bin',\r\n        revision=None,\r\n        cache_dir: Union[str, None] = None,\r\n):\r\n    has_hf_hub(True)\r\n    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)\r\n    return cached_file\r\n\r\n\r\ndef download_pretrained(\r\n        cfg: Dict,\r\n        force_hf_hub: bool = False,\r\n        cache_dir: Union[str, None] = None,\r\n):\r\n    target = ''\r\n    if not cfg:\r\n        return target\r\n\r\n    download_url = cfg.get('url', '')\r\n    download_hf_hub = cfg.get('hf_hub', '')\r\n    if download_hf_hub and force_hf_hub:\r\n        # use HF hub even if url exists\r\n        download_url = ''\r\n\r\n    if download_url:\r\n        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)\r\n    elif download_hf_hub:\r\n        has_hf_hub(True)\r\n        # we assume the hf_hub entries in pretrained config combine model_id + filename in\r\n        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and\r\n        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.\r\n        model_id, filename = os.path.split(download_hf_hub)\r\n        if filename:\r\n            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)\r\n        else:\r\n            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\r\n\r\n    return target\r\n"
  },
  {
    "path": "eva_clip/rope.py",
    "content": "from math import pi\r\nimport torch\r\nfrom torch import nn\r\nfrom einops import rearrange, repeat\r\nimport logging\r\n\r\ndef broadcat(tensors, dim = -1):\r\n    num_tensors = len(tensors)\r\n    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))\r\n    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'\r\n    shape_len = list(shape_lens)[0]\r\n    dim = (dim + shape_len) if dim < 0 else dim\r\n    dims = list(zip(*map(lambda t: list(t.shape), tensors)))\r\n    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]\r\n    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'\r\n    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))\r\n    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))\r\n    expanded_dims.insert(dim, (dim, dims[dim]))\r\n    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))\r\n    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))\r\n    return torch.cat(tensors, dim = dim)\r\n\r\ndef rotate_half(x):\r\n    x = rearrange(x, '... (d r) -> ... d r', r = 2)\r\n    x1, x2 = x.unbind(dim = -1)\r\n    x = torch.stack((-x2, x1), dim = -1)\r\n    return rearrange(x, '... d r -> ... (d r)')\r\n\r\n\r\nclass VisionRotaryEmbedding(nn.Module):\r\n    def __init__(\r\n        self,\r\n        dim,\r\n        pt_seq_len,\r\n        ft_seq_len=None,\r\n        custom_freqs = None,\r\n        freqs_for = 'lang',\r\n        theta = 10000,\r\n        max_freq = 10,\r\n        num_freqs = 1,\r\n    ):\r\n        super().__init__()\r\n        if custom_freqs:\r\n            freqs = custom_freqs\r\n        elif freqs_for == 'lang':\r\n            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))\r\n        elif freqs_for == 'pixel':\r\n            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi\r\n        elif freqs_for == 'constant':\r\n            freqs = torch.ones(num_freqs).float()\r\n        else:\r\n            raise ValueError(f'unknown modality {freqs_for}')\r\n\r\n        if ft_seq_len is None: ft_seq_len = pt_seq_len\r\n        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len\r\n\r\n        freqs_h = torch.einsum('..., f -> ... f', t, freqs)\r\n        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)\r\n\r\n        freqs_w = torch.einsum('..., f -> ... f', t, freqs)\r\n        freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2)\r\n\r\n        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) \r\n\r\n        self.register_buffer(\"freqs_cos\", freqs.cos())\r\n        self.register_buffer(\"freqs_sin\", freqs.sin())\r\n\r\n        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')\r\n\r\n    def forward(self, t, start_index = 0):\r\n        rot_dim = self.freqs_cos.shape[-1]\r\n        end_index = start_index + rot_dim\r\n        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'\r\n        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]\r\n        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)\r\n\r\n        return torch.cat((t_left, t, t_right), dim = -1)\r\n\r\nclass VisionRotaryEmbeddingFast(nn.Module):\r\n    def __init__(\r\n        self,\r\n        dim,\r\n        pt_seq_len,\r\n        ft_seq_len=None,\r\n        custom_freqs = None,\r\n        freqs_for = 'lang',\r\n        theta = 10000,\r\n        max_freq = 10,\r\n        num_freqs = 1,\r\n        patch_dropout = 0.\r\n    ):\r\n        super().__init__()\r\n        if custom_freqs:\r\n            freqs = custom_freqs\r\n        elif freqs_for == 'lang':\r\n            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))\r\n        elif freqs_for == 'pixel':\r\n            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi\r\n        elif freqs_for == 'constant':\r\n            freqs = torch.ones(num_freqs).float()\r\n        else:\r\n            raise ValueError(f'unknown modality {freqs_for}')\r\n\r\n        if ft_seq_len is None: ft_seq_len = pt_seq_len\r\n        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len\r\n\r\n        freqs = torch.einsum('..., f -> ... f', t, freqs)\r\n        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)\r\n        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)\r\n\r\n        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])\r\n        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])\r\n\r\n        self.patch_dropout = patch_dropout\r\n\r\n        self.register_buffer(\"freqs_cos\", freqs_cos)\r\n        self.register_buffer(\"freqs_sin\", freqs_sin)\r\n\r\n        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')\r\n\r\n    def forward(self, t, patch_indices_keep=None):\r\n        if patch_indices_keep is not None:\r\n            batch = t.size()[0]\r\n            batch_indices = torch.arange(batch)\r\n            batch_indices = batch_indices[..., None]\r\n\r\n            freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])\r\n            freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])\r\n\r\n            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]\r\n            freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')\r\n            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]\r\n            freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')\r\n\r\n            return  t * freqs_cos + rotate_half(t) * freqs_sin\r\n\r\n        return  t * self.freqs_cos + rotate_half(t) * self.freqs_sin"
  },
  {
    "path": "eva_clip/timm_model.py",
    "content": "\"\"\" timm model adapter\r\n\r\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.\r\n\"\"\"\r\nimport logging\r\nfrom collections import OrderedDict\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\ntry:\r\n    import timm\r\n    from timm.models.layers import Mlp, to_2tuple\r\n    try:\r\n        # old timm imports < 0.8.1\r\n        from timm.models.layers.attention_pool2d import RotAttentionPool2d\r\n        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d\r\n    except ImportError:\r\n        # new timm imports >= 0.8.1\r\n        from timm.layers import RotAttentionPool2d\r\n        from timm.layers import AttentionPool2d as AbsAttentionPool2d\r\nexcept ImportError:\r\n    timm = None\r\n\r\nfrom .utils import freeze_batch_norm_2d\r\n\r\n\r\nclass TimmModel(nn.Module):\r\n    \"\"\" timm model adapter\r\n    # FIXME this adapter is a work in progress, may change in ways that break weight compat\r\n    \"\"\"\r\n\r\n    def __init__(\r\n            self,\r\n            model_name,\r\n            embed_dim,\r\n            image_size=224,\r\n            pool='avg',\r\n            proj='linear',\r\n            proj_bias=False,\r\n            drop=0.,\r\n            pretrained=False):\r\n        super().__init__()\r\n        if timm is None:\r\n            raise RuntimeError(\"Please `pip install timm` to use timm models.\")\r\n\r\n        self.image_size = to_2tuple(image_size)\r\n        self.trunk = timm.create_model(model_name, pretrained=pretrained)\r\n        feat_size = self.trunk.default_cfg.get('pool_size', None)\r\n        feature_ndim = 1 if not feat_size else 2\r\n        if pool in ('abs_attn', 'rot_attn'):\r\n            assert feature_ndim == 2\r\n            # if attn pooling used, remove both classifier and default pool\r\n            self.trunk.reset_classifier(0, global_pool='')\r\n        else:\r\n            # reset global pool if pool config set, otherwise leave as network default\r\n            reset_kwargs = dict(global_pool=pool) if pool else {}\r\n            self.trunk.reset_classifier(0, **reset_kwargs)\r\n        prev_chs = self.trunk.num_features\r\n\r\n        head_layers = OrderedDict()\r\n        if pool == 'abs_attn':\r\n            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)\r\n            prev_chs = embed_dim\r\n        elif pool == 'rot_attn':\r\n            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)\r\n            prev_chs = embed_dim\r\n        else:\r\n            assert proj, 'projection layer needed if non-attention pooling is used.'\r\n\r\n        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used\r\n        if proj == 'linear':\r\n            head_layers['drop'] = nn.Dropout(drop)\r\n            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)\r\n        elif proj == 'mlp':\r\n            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))\r\n\r\n        self.head = nn.Sequential(head_layers)\r\n\r\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        \"\"\" lock modules\r\n        Args:\r\n            unlocked_groups (int): leave last n layer groups unlocked (default: 0)\r\n        \"\"\"\r\n        if not unlocked_groups:\r\n            # lock full model\r\n            for param in 
self.trunk.parameters():\r\n                param.requires_grad = False\r\n            if freeze_bn_stats:\r\n                freeze_batch_norm_2d(self.trunk)\r\n        else:\r\n            # NOTE: partial freeze requires latest timm (master) branch and is subject to change\r\n            try:\r\n                # FIXME import here until API stable and in an official release\r\n                from timm.models.helpers import group_parameters, group_modules\r\n            except ImportError:\r\n                raise RuntimeError(\r\n                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')\r\n            matcher = self.trunk.group_matcher()\r\n            gparams = group_parameters(self.trunk, matcher)\r\n            max_layer_id = max(gparams.keys())\r\n            max_layer_id = max_layer_id - unlocked_groups\r\n            for group_idx in range(max_layer_id + 1):\r\n                group = gparams[group_idx]\r\n                for param in group:\r\n                    self.trunk.get_parameter(param).requires_grad = False\r\n            if freeze_bn_stats:\r\n                gmodules = group_modules(self.trunk, matcher, reverse=True)\r\n                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}\r\n                freeze_batch_norm_2d(self.trunk, gmodules)\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        try:\r\n            self.trunk.set_grad_checkpointing(enable)\r\n        except Exception as e:\r\n            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')\r\n\r\n    def forward(self, x):\r\n        x = self.trunk(x)\r\n        x = self.head(x)\r\n        return x\r\n"
  },
  {
    "path": "eva_clip/tokenizer.py",
    "content": "\"\"\" CLIP tokenizer\r\n\r\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\r\n\"\"\"\r\nimport gzip\r\nimport html\r\nimport os\r\nfrom functools import lru_cache\r\nfrom typing import Union, List\r\n\r\nimport ftfy\r\nimport regex as re\r\nimport torch\r\n\r\n# https://stackoverflow.com/q/62691279\r\nimport os\r\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\r\n\r\n\r\n@lru_cache()\r\ndef default_bpe():\r\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\")\r\n\r\n\r\n@lru_cache()\r\ndef bytes_to_unicode():\r\n    \"\"\"\r\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\r\n    The reversible bpe codes work on unicode strings.\r\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\r\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\r\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\r\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\r\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\r\n    \"\"\"\r\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\r\n    cs = bs[:]\r\n    n = 0\r\n    for b in range(2**8):\r\n        if b not in bs:\r\n            bs.append(b)\r\n            cs.append(2**8+n)\r\n            n += 1\r\n    cs = [chr(n) for n in cs]\r\n    return dict(zip(bs, cs))\r\n\r\n\r\ndef get_pairs(word):\r\n    \"\"\"Return set of symbol pairs in a word.\r\n    Word is represented as tuple of symbols (symbols being variable-length strings).\r\n    \"\"\"\r\n    pairs = set()\r\n    prev_char = word[0]\r\n    for char in word[1:]:\r\n        pairs.add((prev_char, char))\r\n        prev_char = char\r\n    return pairs\r\n\r\n\r\ndef basic_clean(text):\r\n    text = ftfy.fix_text(text)\r\n    text = html.unescape(html.unescape(text))\r\n    return text.strip()\r\n\r\n\r\ndef whitespace_clean(text):\r\n    text = re.sub(r'\\s+', ' ', text)\r\n    text = text.strip()\r\n    return text\r\n\r\n\r\nclass SimpleTokenizer(object):\r\n    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):\r\n        self.byte_encoder = bytes_to_unicode()\r\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\r\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\r\n        merges = merges[1:49152-256-2+1]\r\n        merges = [tuple(merge.split()) for merge in merges]\r\n        vocab = list(bytes_to_unicode().values())\r\n        vocab = vocab + [v+'</w>' for v in vocab]\r\n        for merge in merges:\r\n            vocab.append(''.join(merge))\r\n        if not special_tokens:\r\n            special_tokens = ['<start_of_text>', '<end_of_text>']\r\n        else:\r\n            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens\r\n        vocab.extend(special_tokens)\r\n        self.encoder = dict(zip(vocab, range(len(vocab))))\r\n        self.decoder = {v: k for k, v in self.encoder.items()}\r\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\r\n        self.cache = {t:t for t in special_tokens}\r\n        special = \"|\".join(special_tokens)\r\n        self.pat = re.compile(special + r\"\"\"|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\", re.IGNORECASE)\r\n\r\n        self.vocab_size = 
len(self.encoder)\r\n        self.all_special_ids = [self.encoder[t] for t in special_tokens]\r\n\r\n    def bpe(self, token):\r\n        if token in self.cache:\r\n            return self.cache[token]\r\n        word = tuple(token[:-1]) + ( token[-1] + '</w>',)\r\n        pairs = get_pairs(word)\r\n\r\n        if not pairs:\r\n            return token+'</w>'\r\n\r\n        while True:\r\n            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\r\n            if bigram not in self.bpe_ranks:\r\n                break\r\n            first, second = bigram\r\n            new_word = []\r\n            i = 0\r\n            while i < len(word):\r\n                try:\r\n                    j = word.index(first, i)\r\n                    new_word.extend(word[i:j])\r\n                    i = j\r\n                except:\r\n                    new_word.extend(word[i:])\r\n                    break\r\n\r\n                if word[i] == first and i < len(word)-1 and word[i+1] == second:\r\n                    new_word.append(first+second)\r\n                    i += 2\r\n                else:\r\n                    new_word.append(word[i])\r\n                    i += 1\r\n            new_word = tuple(new_word)\r\n            word = new_word\r\n            if len(word) == 1:\r\n                break\r\n            else:\r\n                pairs = get_pairs(word)\r\n        word = ' '.join(word)\r\n        self.cache[token] = word\r\n        return word\r\n\r\n    def encode(self, text):\r\n        bpe_tokens = []\r\n        text = whitespace_clean(basic_clean(text)).lower()\r\n        for token in re.findall(self.pat, text):\r\n            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\r\n            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\r\n        return bpe_tokens\r\n\r\n    def decode(self, tokens):\r\n        text = ''.join([self.decoder[token] for token in tokens])\r\n        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\r\n        return text\r\n\r\n\r\n_tokenizer = SimpleTokenizer()\r\n\r\n\r\ndef tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:\r\n    \"\"\"\r\n    Returns the tokenized representation of given input string(s)\r\n\r\n    Parameters\r\n    ----------\r\n    texts : Union[str, List[str]]\r\n        An input string or a list of input strings to tokenize\r\n    context_length : int\r\n        The context length to use; all CLIP models use 77 as the context length\r\n\r\n    Returns\r\n    -------\r\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\r\n    \"\"\"\r\n    if isinstance(texts, str):\r\n        texts = [texts]\r\n\r\n    sot_token = _tokenizer.encoder[\"<start_of_text>\"]\r\n    eot_token = _tokenizer.encoder[\"<end_of_text>\"]\r\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]\r\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\r\n\r\n    for i, tokens in enumerate(all_tokens):\r\n        if len(tokens) > context_length:\r\n            tokens = tokens[:context_length]  # Truncate\r\n            tokens[-1] = eot_token\r\n        result[i, :len(tokens)] = torch.tensor(tokens)\r\n\r\n    return result\r\n\r\n\r\nclass HFTokenizer:\r\n    \"HuggingFace tokenizer wrapper\"\r\n    def __init__(self, tokenizer_name:str):\r\n        from 
transformers import AutoTokenizer\r\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\r\n\r\n    def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:\r\n        # same cleaning as for default tokenizer, except lowercasing\r\n        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance\r\n        if isinstance(texts, str):\r\n            texts = [texts]\r\n        texts = [whitespace_clean(basic_clean(text)) for text in texts]\r\n        input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids\r\n        return input_ids\r\n"
  },
  {
    "path": "eva_clip/transform.py",
    "content": "from typing import Optional, Sequence, Tuple\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torchvision.transforms.functional as F\r\n\r\nfrom torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\r\n    CenterCrop\r\n\r\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\r\n\r\n\r\nclass ResizeMaxSize(nn.Module):\r\n\r\n    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):\r\n        super().__init__()\r\n        if not isinstance(max_size, int):\r\n            raise TypeError(f\"Size should be int. Got {type(max_size)}\")\r\n        self.max_size = max_size\r\n        self.interpolation = interpolation\r\n        self.fn = min if fn == 'min' else min\r\n        self.fill = fill\r\n\r\n    def forward(self, img):\r\n        if isinstance(img, torch.Tensor):\r\n            height, width = img.shape[:2]\r\n        else:\r\n            width, height = img.size\r\n        scale = self.max_size / float(max(height, width))\r\n        if scale != 1.0:\r\n            new_size = tuple(round(dim * scale) for dim in (height, width))\r\n            img = F.resize(img, new_size, self.interpolation)\r\n            pad_h = self.max_size - new_size[0]\r\n            pad_w = self.max_size - new_size[1]\r\n            img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)\r\n        return img\r\n\r\n\r\ndef _convert_to_rgb(image):\r\n    return image.convert('RGB')\r\n\r\n\r\n# class CatGen(nn.Module):\r\n#     def __init__(self, num=4):\r\n#         self.num = num\r\n#     def mixgen_batch(image, text):\r\n#         batch_size = image.shape[0]\r\n#         index = np.random.permutation(batch_size)\r\n\r\n#         cat_images = []\r\n#         for i in range(batch_size):\r\n#             # image mixup\r\n#             image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]\r\n#             # text concat\r\n#             text[i] = tokenizer((str(text[i]) + \" \" + str(text[index[i]])))[0]\r\n#         text = torch.stack(text)\r\n#         return image, text\r\n\r\n\r\ndef image_transform(\r\n        image_size: int,\r\n        is_train: bool,\r\n        mean: Optional[Tuple[float, ...]] = None,\r\n        std: Optional[Tuple[float, ...]] = None,\r\n        resize_longest_max: bool = False,\r\n        fill_color: int = 0,\r\n):\r\n    mean = mean or OPENAI_DATASET_MEAN\r\n    if not isinstance(mean, (list, tuple)):\r\n        mean = (mean,) * 3\r\n\r\n    std = std or OPENAI_DATASET_STD\r\n    if not isinstance(std, (list, tuple)):\r\n        std = (std,) * 3\r\n\r\n    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:\r\n        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge\r\n        image_size = image_size[0]\r\n\r\n    normalize = Normalize(mean=mean, std=std)\r\n    if is_train:\r\n        return Compose([\r\n            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),\r\n            _convert_to_rgb,\r\n            ToTensor(),\r\n            normalize,\r\n        ])\r\n    else:\r\n        if resize_longest_max:\r\n            transforms = [\r\n                ResizeMaxSize(image_size, fill=fill_color)\r\n            ]\r\n        else:\r\n            transforms = [\r\n                Resize(image_size, interpolation=InterpolationMode.BICUBIC),\r\n                CenterCrop(image_size),\r\n            ]\r\n        
transforms.extend([\r\n            _convert_to_rgb,\r\n            ToTensor(),\r\n            normalize,\r\n        ])\r\n        return Compose(transforms)\r\n"
  },
  {
    "path": "eva_clip/transformer.py",
    "content": "import os\r\nimport logging\r\nfrom collections import OrderedDict\r\nimport math\r\nfrom typing import Callable, Optional, Sequence\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\ntry:\r\n    from timm.models.layers import trunc_normal_\r\nexcept:\r\n    from timm.layers import trunc_normal_\r\n    \r\nfrom .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast\r\nfrom .utils import to_2tuple\r\n\r\nif os.getenv('ENV_TYPE') == 'deepspeed':\r\n    try:\r\n        import deepspeed\r\n        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint\r\n    except:\r\n        print(\"Please 'pip install deepspeed'\")\r\n        deepspeed = None\r\n        from torch.utils.checkpoint import checkpoint\r\nelse:\r\n    from torch.utils.checkpoint import checkpoint\r\n\r\ntry:\r\n    import xformers.ops as xops\r\nexcept ImportError:\r\n    xops = None\r\n    print(\"Please 'pip install xformers'\")\r\n\r\nclass LayerNormFp32(nn.LayerNorm):\r\n    \"\"\"Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).\"\"\"\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        output = F.layer_norm(\r\n            x.float(),\r\n            self.normalized_shape,\r\n            self.weight.float() if self.weight is not None else None,\r\n            self.bias.float() if self.bias is not None else None,\r\n            self.eps,\r\n        )\r\n        return output.type_as(x)\r\n\r\n\r\nclass LayerNorm(nn.LayerNorm):\r\n    \"\"\"Subclass torch's LayerNorm (with cast back to input dtype).\"\"\"\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        orig_type = x.dtype\r\n        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\r\n        return x.to(orig_type)\r\n\r\nclass QuickGELU(nn.Module):\r\n    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory\r\n    def forward(self, x: torch.Tensor):\r\n        return x * torch.sigmoid(1.702 * x)\r\n\r\n\r\nclass LayerScale(nn.Module):\r\n    def __init__(self, dim, init_values=1e-5, inplace=False):\r\n        super().__init__()\r\n        self.inplace = inplace\r\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\r\n\r\n    def forward(self, x):\r\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\r\n\r\nclass PatchDropout(nn.Module):\r\n    \"\"\"\r\n    https://arxiv.org/abs/2212.00794\r\n    \"\"\"\r\n\r\n    def __init__(self, prob, exclude_first_token=True):\r\n        super().__init__()\r\n        assert 0 <= prob < 1.\r\n        self.prob = prob\r\n        self.exclude_first_token = exclude_first_token  # exclude CLS token\r\n        logging.info(f\"os.getenv('RoPE')={os.getenv('RoPE')}\")\r\n\r\n    def forward(self, x):\r\n        if not self.training or self.prob == 0.:\r\n            return x\r\n\r\n        if self.exclude_first_token:\r\n            cls_tokens, x = x[:, :1], x[:, 1:]\r\n        else:\r\n            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])\r\n\r\n        batch = x.size()[0]\r\n        num_tokens = x.size()[1]\r\n\r\n        batch_indices = torch.arange(batch)\r\n        batch_indices = batch_indices[..., None]\r\n\r\n        keep_prob = 1 - self.prob\r\n        num_patches_keep = max(1, int(num_tokens * keep_prob))\r\n\r\n        rand = torch.randn(batch, num_tokens)\r\n        patch_indices_keep = 
rand.topk(num_patches_keep, dim=-1).indices\r\n\r\n        x = x[batch_indices, patch_indices_keep]\r\n\r\n        if self.exclude_first_token:\r\n            x = torch.cat((cls_tokens, x), dim=1)\r\n\r\n        if self.training and os.getenv('RoPE') == '1':\r\n            return x, patch_indices_keep\r\n\r\n        return x\r\n\r\n\r\ndef _in_projection_packed(\r\n    q: torch.Tensor,\r\n    k: torch.Tensor,\r\n    v: torch.Tensor,\r\n    w: torch.Tensor,\r\n    b: Optional[torch.Tensor] = None,\r\n    ):\r\n    \"\"\"\r\n    https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726\r\n    \"\"\"\r\n    E = q.size(-1)\r\n    if k is v:\r\n        if q is k:\r\n            # self-attention\r\n            return F.linear(q, w, b).chunk(3, dim=-1)\r\n        else:\r\n            # encoder-decoder attention\r\n            w_q, w_kv = w.split([E, E * 2])\r\n            if b is None:\r\n                b_q = b_kv = None\r\n            else:\r\n                b_q, b_kv = b.split([E, E * 2])\r\n            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)\r\n    else:\r\n        w_q, w_k, w_v = w.chunk(3)\r\n        if b is None:\r\n            b_q = b_k = b_v = None\r\n        else:\r\n            b_q, b_k, b_v = b.chunk(3)\r\n        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)\r\n\r\nclass Attention(nn.Module):\r\n    def __init__(\r\n            self,\r\n            dim,\r\n            num_heads=8,\r\n            qkv_bias=True,\r\n            scaled_cosine=False,\r\n            scale_heads=False,\r\n            logit_scale_max=math.log(1. / 0.01),\r\n            attn_drop=0.,\r\n            proj_drop=0.,\r\n            xattn=False,\r\n            rope=False\r\n    ):\r\n        super().__init__()\r\n        self.scaled_cosine = scaled_cosine\r\n        self.scale_heads = scale_heads\r\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\r\n        self.num_heads = num_heads\r\n        self.head_dim = dim // num_heads\r\n        self.scale = self.head_dim ** -0.5\r\n        self.logit_scale_max = logit_scale_max\r\n\r\n        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\r\n        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\r\n        if qkv_bias:\r\n            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\r\n        else:\r\n            self.in_proj_bias = None\r\n\r\n        if self.scaled_cosine:\r\n            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\r\n        else:\r\n            self.logit_scale = None\r\n        self.attn_drop = nn.Dropout(attn_drop)\r\n        if self.scale_heads:\r\n            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\r\n        else:\r\n            self.head_scale = None\r\n        self.out_proj = nn.Linear(dim, dim)\r\n        self.out_drop = nn.Dropout(proj_drop)\r\n        self.xattn = xattn\r\n        self.xattn_drop = attn_drop\r\n        self.rope = rope\r\n\r\n    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):\r\n        L, N, C = x.shape\r\n        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)\r\n        if self.xattn:\r\n            q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)\r\n            k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)\r\n            v = v.contiguous().view(L, N, 
self.num_heads, -1).transpose(0, 1)\r\n\r\n            x = xops.memory_efficient_attention(\r\n                q, k, v,\r\n                p=self.xattn_drop,\r\n                scale=self.scale if self.logit_scale is None else None,\r\n                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,\r\n                )\r\n        else:\r\n            q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\r\n            k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\r\n            v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)\r\n\r\n            if self.logit_scale is not None:\r\n                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\r\n                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\r\n                attn = attn.view(N, self.num_heads, L, L) * logit_scale\r\n                attn = attn.view(-1, L, L)\r\n            else:\r\n                q = q * self.scale\r\n                attn = torch.bmm(q, k.transpose(-1, -2))\r\n\r\n            if attn_mask is not None:\r\n                if attn_mask.dtype == torch.bool:\r\n                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\r\n                    new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\r\n                    attn_mask = new_attn_mask\r\n                attn += attn_mask\r\n\r\n            attn = attn.softmax(dim=-1)\r\n            attn = self.attn_drop(attn)\r\n\r\n            x = torch.bmm(attn, v)\r\n\r\n        if self.head_scale is not None:\r\n            x = x.view(N, self.num_heads, L, C) * self.head_scale\r\n            x = x.view(-1, L, C)\r\n        x = x.transpose(0, 1).reshape(L, N, C)\r\n        x = self.out_proj(x)\r\n        x = self.out_drop(x)\r\n        return x\r\n\r\nclass CustomAttention(nn.Module):\r\n    def __init__(\r\n            self,\r\n            dim,\r\n            num_heads=8,\r\n            qkv_bias=True,\r\n            scaled_cosine=True,\r\n            scale_heads=False,\r\n            logit_scale_max=math.log(1. 
/ 0.01),\r\n            attn_drop=0.,\r\n            proj_drop=0.,\r\n            xattn=False\r\n    ):\r\n        super().__init__()\r\n        self.scaled_cosine = scaled_cosine\r\n        self.scale_heads = scale_heads\r\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\r\n        self.num_heads = num_heads\r\n        self.head_dim = dim // num_heads\r\n        self.scale = self.head_dim ** -0.5\r\n        self.logit_scale_max = logit_scale_max\r\n\r\n        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\r\n        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\r\n        if qkv_bias:\r\n            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\r\n        else:\r\n            self.in_proj_bias = None\r\n\r\n        if self.scaled_cosine:\r\n            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\r\n        else:\r\n            self.logit_scale = None\r\n        self.attn_drop = nn.Dropout(attn_drop)\r\n        if self.scale_heads:\r\n            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\r\n        else:\r\n            self.head_scale = None\r\n        self.out_proj = nn.Linear(dim, dim)\r\n        self.out_drop = nn.Dropout(proj_drop)\r\n        self.xattn = xattn\r\n        self.xattn_drop = attn_drop\r\n\r\n    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\r\n        q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)\r\n        N_q, B_q, C_q = q.shape\r\n        N_k, B_k, C_k = k.shape\r\n        N_v, B_v, C_v = v.shape\r\n        if self.xattn:\r\n            # B, N, C -> B, N, num_heads, C\r\n            q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)\r\n            k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)\r\n            v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)\r\n\r\n            x = xops.memory_efficient_attention(\r\n                q, k, v,\r\n                p=self.xattn_drop,\r\n                scale=self.scale if self.logit_scale is None else None,\r\n                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None\r\n                )\r\n        else:\r\n            # B*H, L, C\r\n            q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)\r\n            k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)\r\n            v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)\r\n\r\n            if self.logit_scale is not None:\r\n                # B*H, N_q, N_k\r\n                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\r\n                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\r\n                attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale\r\n                attn = attn.view(-1, N_q, N_k)\r\n            else:\r\n                q = q * self.scale\r\n                attn = torch.bmm(q, k.transpose(-1, -2))\r\n\r\n            if attn_mask is not None:\r\n                if attn_mask.dtype == torch.bool:\r\n                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\r\n                    new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\r\n                    attn_mask = new_attn_mask\r\n                attn += attn_mask\r\n\r\n            attn = 
attn.softmax(dim=-1)\r\n            attn = self.attn_drop(attn)\r\n\r\n            x = torch.bmm(attn, v)\r\n            \r\n        if self.head_scale is not None:\r\n            x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale\r\n            x = x.view(-1, N_q, C_q)\r\n        x = x.transpose(0, 1).reshape(N_q, B_q, C_q)\r\n        x = self.out_proj(x)\r\n        x = self.out_drop(x)\r\n        return x\r\n\r\nclass CustomResidualAttentionBlock(nn.Module):\r\n    def __init__(\r\n            self,\r\n            d_model: int,\r\n            n_head: int,\r\n            mlp_ratio: float = 4.0,\r\n            ls_init_value: float = None,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            scale_cosine_attn: bool = False,\r\n            scale_heads: bool = False,\r\n            scale_attn: bool = False,\r\n            scale_fc: bool = False,\r\n            cross_attn: bool = False,\r\n            xattn: bool = False,\r\n    ):\r\n        super().__init__()\r\n\r\n        self.ln_1 = norm_layer(d_model)\r\n        self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1\r\n        self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1\r\n        self.attn = CustomAttention(\r\n            d_model, n_head,\r\n            qkv_bias=True,\r\n            attn_drop=0.,\r\n            proj_drop=0.,\r\n            scaled_cosine=scale_cosine_attn,\r\n            scale_heads=scale_heads,\r\n            xattn=xattn\r\n        )\r\n\r\n        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()\r\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\r\n\r\n        self.ln_2 = norm_layer(d_model)\r\n        mlp_width = int(d_model * mlp_ratio)\r\n        self.mlp = nn.Sequential(OrderedDict([\r\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\r\n            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),\r\n            (\"gelu\", act_layer()),\r\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\r\n        ]))\r\n\r\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\r\n\r\n    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\r\n        q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))\r\n        q = q + self.ls_2(self.mlp(self.ln_2(q)))\r\n        return q\r\n\r\nclass CustomTransformer(nn.Module):\r\n    def __init__(\r\n            self,\r\n            width: int,\r\n            layers: int,\r\n            heads: int,\r\n            mlp_ratio: float = 4.0,\r\n            ls_init_value: float = None,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            scale_cosine_attn: bool = True,\r\n            scale_heads: bool = False,\r\n            scale_attn: bool = False,\r\n            scale_fc: bool = False,\r\n            cross_attn: bool = False,\r\n            xattn: bool = False,\r\n    ):\r\n        super().__init__()\r\n        self.width = width\r\n        self.layers = layers\r\n        self.grad_checkpointing = False\r\n        self.xattn = xattn\r\n\r\n        self.resblocks = nn.ModuleList([\r\n            CustomResidualAttentionBlock(\r\n                width,\r\n                heads,\r\n                mlp_ratio,\r\n                ls_init_value=ls_init_value,\r\n                
act_layer=act_layer,\r\n                norm_layer=norm_layer,\r\n                scale_cosine_attn=scale_cosine_attn,\r\n                scale_heads=scale_heads,\r\n                scale_attn=scale_attn,\r\n                scale_fc=scale_fc,\r\n                cross_attn=cross_attn,\r\n                xattn=xattn)\r\n            for _ in range(layers)\r\n        ])\r\n\r\n    def get_cast_dtype(self) -> torch.dtype:\r\n        return self.resblocks[0].mlp.c_fc.weight.dtype \r\n\r\n    def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):\r\n        if k is None and v is None:\r\n            k = v = q\r\n        for r in self.resblocks:\r\n            if self.grad_checkpointing and not torch.jit.is_scripting():\r\n                q = checkpoint(r, q, k, v, attn_mask)\r\n            else:\r\n                q = r(q, k, v, attn_mask=attn_mask)\r\n        return q\r\n\r\n\r\nclass ResidualAttentionBlock(nn.Module):\r\n    def __init__(\r\n            self,\r\n            d_model: int,\r\n            n_head: int,\r\n            mlp_ratio: float = 4.0,\r\n            ls_init_value: float = None,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            xattn: bool = False,\r\n    ):\r\n        super().__init__()\r\n\r\n        self.ln_1 = norm_layer(d_model)\r\n        if xattn:\r\n            self.attn = Attention(d_model, n_head, xattn=True)\r\n        else:\r\n            self.attn = nn.MultiheadAttention(d_model, n_head)\r\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\r\n\r\n        self.ln_2 = norm_layer(d_model)\r\n        mlp_width = int(d_model * mlp_ratio)\r\n        self.mlp = nn.Sequential(OrderedDict([\r\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\r\n            (\"gelu\", act_layer()),\r\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\r\n        ]))\r\n\r\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\r\n        self.xattn = xattn\r\n\r\n    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\r\n        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None\r\n        if self.xattn:\r\n            return self.attn(x, attn_mask=attn_mask)\r\n        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]\r\n\r\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\r\n        x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))\r\n        x = x + self.ls_2(self.mlp(self.ln_2(x)))\r\n        return x\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(\r\n            self,\r\n            width: int,\r\n            layers: int,\r\n            heads: int,\r\n            mlp_ratio: float = 4.0,\r\n            ls_init_value: float = None,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            xattn: bool = False,\r\n    ):\r\n        super().__init__()\r\n        self.width = width\r\n        self.layers = layers\r\n        self.grad_checkpointing = False\r\n\r\n        self.resblocks = nn.ModuleList([\r\n            ResidualAttentionBlock(\r\n                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)\r\n            for _ in range(layers)\r\n        ])\r\n\r\n    def get_cast_dtype(self) -> torch.dtype:\r\n    
    return self.resblocks[0].mlp.c_fc.weight.dtype\r\n\r\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\r\n        for r in self.resblocks:\r\n            if self.grad_checkpointing and not torch.jit.is_scripting():\r\n                x = checkpoint(r, x, attn_mask)\r\n            else:\r\n                x = r(x, attn_mask=attn_mask)\r\n        return x\r\n\r\n\r\nclass VisionTransformer(nn.Module):\r\n    def __init__(\r\n            self,\r\n            image_size: int,\r\n            patch_size: int,\r\n            width: int,\r\n            layers: int,\r\n            heads: int,\r\n            mlp_ratio: float,\r\n            ls_init_value: float = None,\r\n            patch_dropout: float = 0.,\r\n            global_average_pool: bool = False,\r\n            output_dim: int = 512,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            xattn: bool = False,\r\n    ):\r\n        super().__init__()\r\n        self.image_size = to_2tuple(image_size)\r\n        self.patch_size = to_2tuple(patch_size)\r\n        self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])\r\n        self.output_dim = output_dim\r\n        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)\r\n\r\n        scale = width ** -0.5\r\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\r\n        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))\r\n\r\n        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\r\n        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity()\r\n        self.ln_pre = norm_layer(width)\r\n        \r\n        self.transformer = Transformer(\r\n            width,\r\n            layers,\r\n            heads,\r\n            mlp_ratio,\r\n            ls_init_value=ls_init_value,\r\n            act_layer=act_layer,\r\n            norm_layer=norm_layer,\r\n            xattn=xattn\r\n        )\r\n\r\n        self.global_average_pool = global_average_pool\r\n        self.ln_post = norm_layer(width)\r\n        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\r\n\r\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\r\n        for param in self.parameters():\r\n            param.requires_grad = False\r\n        \r\n        if unlocked_groups != 0:\r\n            groups = [\r\n                [\r\n                    self.conv1,\r\n                    self.class_embedding,\r\n                    self.positional_embedding,\r\n                    self.ln_pre,\r\n                ],\r\n                *self.transformer.resblocks[:-1],\r\n                [\r\n                    self.transformer.resblocks[-1],\r\n                    self.ln_post,\r\n                ],\r\n                self.proj,\r\n            ]\r\n\r\n            def _unlock(x):\r\n                if isinstance(x, Sequence):\r\n                    for g in x:\r\n                        _unlock(g)\r\n                else:\r\n                    if isinstance(x, torch.nn.Parameter):\r\n                        x.requires_grad = True\r\n                    else:\r\n                        for p in x.parameters():\r\n                            p.requires_grad = True\r\n\r\n            _unlock(groups[-unlocked_groups:])\r\n\r\n    def get_num_layers(self):\r\n        return self.transformer.layers\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.transformer.grad_checkpointing = enable\r\n\r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        return {'positional_embedding', 'class_embedding'}\r\n\r\n    def forward(self, x: torch.Tensor, return_all_features: bool=False):\r\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\r\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\r\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\r\n        x = torch.cat(\r\n            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),\r\n             x], dim=1)  # shape = [*, grid ** 2 + 1, width]\r\n        x = x + self.positional_embedding.to(x.dtype)\r\n\r\n        # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in\r\n        x = self.patch_dropout(x)\r\n        x = self.ln_pre(x)\r\n\r\n        x = x.permute(1, 0, 2)  # NLD -> LND\r\n        x = self.transformer(x)\r\n        x = x.permute(1, 0, 2)  # LND -> NLD\r\n\r\n        if not return_all_features:\r\n            if self.global_average_pool:\r\n                x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)\r\n            else:\r\n                x = x[:, 0]\r\n\r\n            x = self.ln_post(x)\r\n\r\n            if self.proj is not None:\r\n                x = x @ self.proj\r\n\r\n        return x\r\n\r\n\r\nclass TextTransformer(nn.Module):\r\n    def __init__(\r\n            self,\r\n            context_length: int = 77,\r\n            vocab_size: int = 49408,\r\n            width: int = 512,\r\n            heads: int = 8,\r\n            layers: int = 12,\r\n            ls_init_value: float = None,\r\n            output_dim: int = 512,\r\n            act_layer: Callable = nn.GELU,\r\n            norm_layer: Callable = LayerNorm,\r\n            xattn: bool= False,\r\n            attn_mask: bool = True\r\n    ):\r\n        super().__init__()\r\n        self.context_length = context_length\r\n        self.vocab_size = vocab_size\r\n        self.width = width\r\n        self.output_dim = output_dim\r\n\r\n        self.token_embedding = nn.Embedding(vocab_size, width)\r\n        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))\r\n        self.transformer = Transformer(\r\n            width=width,\r\n            layers=layers,\r\n            heads=heads,\r\n            ls_init_value=ls_init_value,\r\n            act_layer=act_layer,\r\n            norm_layer=norm_layer,\r\n            xattn=xattn\r\n        )\r\n        \r\n        self.xattn = xattn\r\n        self.ln_final = norm_layer(width)\r\n        self.text_projection = nn.Parameter(torch.empty(width, output_dim))\r\n\r\n        if attn_mask:\r\n            self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\r\n        else:\r\n            self.attn_mask = None\r\n\r\n        self.init_parameters()\r\n\r\n    def init_parameters(self):\r\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\r\n        nn.init.normal_(self.positional_embedding, std=0.01)\r\n\r\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\r\n        attn_std = self.transformer.width ** -0.5\r\n        fc_std = (2 * self.transformer.width) ** -0.5\r\n        for block in self.transformer.resblocks:\r\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\r\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\r\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\r\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\r\n\r\n        if self.text_projection is not None:\r\n            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\r\n\r\n    @torch.jit.ignore\r\n    def set_grad_checkpointing(self, enable=True):\r\n        self.transformer.grad_checkpointing = enable\r\n    \r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        # return {'positional_embedding', 'token_embedding'}\r\n        return {'positional_embedding'}\r\n\r\n    def get_num_layers(self):\r\n        return self.transformer.layers\r\n\r\n    def build_attention_mask(self):\r\n        # lazily create causal attention mask, with full attention between 
the vision tokens\r\n        # pytorch uses additive attention mask; fill with -inf\r\n        mask = torch.empty(self.context_length, self.context_length)\r\n        mask.fill_(float(\"-inf\"))\r\n        mask.triu_(1)  # zero out the lower diagonal\r\n        return mask\r\n\r\n    def forward(self, text, return_all_features: bool=False):\r\n        cast_dtype = self.transformer.get_cast_dtype()\r\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\r\n\r\n        x = x + self.positional_embedding.to(cast_dtype)\r\n        x = x.permute(1, 0, 2)  # NLD -> LND\r\n        x = self.transformer(x, attn_mask=self.attn_mask)\r\n        # x = self.transformer(x) # no attention mask is applied\r\n        x = x.permute(1, 0, 2)  # LND -> NLD\r\n        x = self.ln_final(x)\r\n\r\n        if not return_all_features:\r\n            # x.shape = [batch_size, n_ctx, transformer.width]\r\n            # take features from the eot embedding (eot_token is the highest number in each sequence)\r\n            x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\r\n        return x\r\n"
  },
  {
    "path": "eva_clip/utils.py",
    "content": "from itertools import repeat\r\nimport collections.abc\r\nimport logging\r\nimport math\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom torch import nn as nn\r\nfrom torchvision.ops.misc import FrozenBatchNorm2d\r\nimport torch.nn.functional as F\r\n\r\n# open CLIP\r\ndef resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):\r\n    # Rescale the grid of position embeddings when loading from state_dict\r\n    old_pos_embed = state_dict.get('visual.positional_embedding', None)\r\n    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\r\n        return\r\n    grid_size = to_2tuple(model.visual.grid_size)\r\n    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)\r\n    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\r\n    if new_seq_len == old_pos_embed.shape[0]:\r\n        return\r\n\r\n    if extra_tokens:\r\n        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\r\n    else:\r\n        pos_emb_tok, pos_emb_img = None, old_pos_embed\r\n    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\r\n\r\n    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\r\n    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\r\n    pos_emb_img = F.interpolate(\r\n        pos_emb_img,\r\n        size=grid_size,\r\n        mode=interpolation,\r\n        align_corners=True,\r\n    )\r\n    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\r\n    if pos_emb_tok is not None:\r\n        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\r\n    else:\r\n        new_pos_embed = pos_emb_img\r\n    state_dict['visual.positional_embedding'] = new_pos_embed\r\n\r\n\r\ndef resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):\r\n    # Rescale the grid of position embeddings when loading from state_dict\r\n    old_pos_embed = state_dict.get('positional_embedding', None)\r\n    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\r\n        return\r\n    grid_size = to_2tuple(model.visual.grid_size)\r\n    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)\r\n    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\r\n    if new_seq_len == old_pos_embed.shape[0]:\r\n        return\r\n\r\n    if extra_tokens:\r\n        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\r\n    else:\r\n        pos_emb_tok, pos_emb_img = None, old_pos_embed\r\n    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\r\n\r\n    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\r\n    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\r\n    pos_emb_img = F.interpolate(\r\n        pos_emb_img,\r\n        size=grid_size,\r\n        mode=interpolation,\r\n        align_corners=True,\r\n    )\r\n    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\r\n    if pos_emb_tok is not None:\r\n        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\r\n    else:\r\n        new_pos_embed = pos_emb_img\r\n    state_dict['positional_embedding'] = new_pos_embed\r\n\r\ndef resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):\r\n    all_keys = 
list(state_dict.keys())\r\n    # interpolate position embedding\r\n    if 'visual.pos_embed' in state_dict:\r\n        pos_embed_checkpoint = state_dict['visual.pos_embed']\r\n        embedding_size = pos_embed_checkpoint.shape[-1]\r\n        num_patches = model.visual.patch_embed.num_patches\r\n        num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches\r\n        # height (== width) for the checkpoint position embedding\r\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\r\n        # height (== width) for the new position embedding\r\n        new_size = int(num_patches ** 0.5)\r\n        # class_token and dist_token are kept unchanged\r\n        if orig_size != new_size:\r\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\r\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\r\n            # only the position tokens are interpolated\r\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\r\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\r\n            pos_tokens = torch.nn.functional.interpolate(\r\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\r\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\r\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\r\n            state_dict['visual.pos_embed'] = new_pos_embed\r\n\r\n            patch_embed_proj = state_dict['visual.patch_embed.proj.weight']\r\n            patch_size = model.visual.patch_embed.patch_size\r\n            state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(\r\n                patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)\r\n\r\n\r\ndef resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):\r\n    all_keys = list(state_dict.keys())\r\n    # interpolate position embedding\r\n    if 'pos_embed' in state_dict:\r\n        pos_embed_checkpoint = state_dict['pos_embed']\r\n        embedding_size = pos_embed_checkpoint.shape[-1]\r\n        num_patches = model.visual.patch_embed.num_patches\r\n        num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches\r\n        # height (== width) for the checkpoint position embedding\r\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\r\n        # height (== width) for the new position embedding\r\n        new_size = int(num_patches ** 0.5)\r\n        # class_token and dist_token are kept unchanged\r\n        if orig_size != new_size:\r\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\r\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\r\n            # only the position tokens are interpolated\r\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\r\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\r\n            pos_tokens = torch.nn.functional.interpolate(\r\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\r\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\r\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\r\n            state_dict['pos_embed'] = new_pos_embed\r\n\r\n            patch_embed_proj = 
state_dict['patch_embed.proj.weight']\r\n            patch_size = model.visual.patch_embed.patch_size\r\n            state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(\r\n                patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)\r\n                \r\n\r\ndef resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):\r\n    from scipy import interpolate  # SciPy provides the cubic 2D interpolation used to resize the relative position bias table\r\n    all_keys = list(state_dict.keys())\r\n    for key in all_keys:\r\n        if \"relative_position_index\" in key:\r\n            state_dict.pop(key)\r\n\r\n        if \"relative_position_bias_table\" in key:\r\n            rel_pos_bias = state_dict[key]\r\n            src_num_pos, num_attn_heads = rel_pos_bias.size()\r\n            dst_num_pos, _ = model.visual.state_dict()[key].size()\r\n            dst_patch_shape = model.visual.patch_embed.patch_shape\r\n            if dst_patch_shape[0] != dst_patch_shape[1]:\r\n                raise NotImplementedError()\r\n            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)\r\n            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)\r\n            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)\r\n            if src_size != dst_size:\r\n                print(\"Position interpolate for %s from %dx%d to %dx%d\" % (\r\n                    key, src_size, src_size, dst_size, dst_size))\r\n                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]\r\n                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]\r\n\r\n                def geometric_progression(a, r, n):\r\n                    return a * (1.0 - r ** n) / (1.0 - r)\r\n\r\n                left, right = 1.01, 1.5\r\n                while right - left > 1e-6:\r\n                    q = (left + right) / 2.0\r\n                    gp = geometric_progression(1, q, src_size // 2)\r\n                    if gp > dst_size // 2:\r\n                        right = q\r\n                    else:\r\n                        left = q\r\n\r\n                # if q > 1.090307:\r\n                #     q = 1.090307\r\n\r\n                dis = []\r\n                cur = 1\r\n                for i in range(src_size // 2):\r\n                    dis.append(cur)\r\n                    cur += q ** (i + 1)\r\n\r\n                r_ids = [-_ for _ in reversed(dis)]\r\n\r\n                x = r_ids + [0] + dis\r\n                y = r_ids + [0] + dis\r\n\r\n                t = dst_size // 2.0\r\n                dx = np.arange(-t, t + 0.1, 1.0)\r\n                dy = np.arange(-t, t + 0.1, 1.0)\r\n\r\n                print(\"Original positions = %s\" % str(x))\r\n                print(\"Target positions = %s\" % str(dx))\r\n\r\n                all_rel_pos_bias = []\r\n\r\n                for i in range(num_attn_heads):\r\n                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()\r\n                    f = interpolate.interp2d(x, y, z, kind='cubic')\r\n                    all_rel_pos_bias.append(\r\n                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))\r\n\r\n                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\r\n\r\n                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)\r\n                state_dict[key] = new_rel_pos_bias\r\n\r\n    # interpolate position embedding\r\n    if 'pos_embed' in state_dict:\r\n        pos_embed_checkpoint = state_dict['pos_embed']\r\n        embedding_size = 
pos_embed_checkpoint.shape[-1]\r\n        num_patches = model.visual.patch_embed.num_patches\r\n        num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches\r\n        # height (== width) for the checkpoint position embedding\r\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\r\n        # height (== width) for the new position embedding\r\n        new_size = int(num_patches ** 0.5)\r\n        # class_token and dist_token are kept unchanged\r\n        if orig_size != new_size:\r\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\r\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\r\n            # only the position tokens are interpolated\r\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\r\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\r\n            pos_tokens = torch.nn.functional.interpolate(\r\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\r\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\r\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\r\n            state_dict['pos_embed'] = new_pos_embed\r\n\r\n            patch_embed_proj = state_dict['patch_embed.proj.weight']\r\n            patch_size = model.visual.patch_embed.patch_size\r\n            state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(\r\n                patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)\r\n\r\n\r\ndef freeze_batch_norm_2d(module, module_match={}, name=''):\r\n    \"\"\"\r\n    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\r\n    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\r\n    returned. 
Otherwise, the module is walked recursively and submodules are converted in place.\r\n\r\n    Args:\r\n        module (torch.nn.Module): Any PyTorch module.\r\n        module_match (dict): Dictionary of full module names to freeze (all if empty)\r\n        name (str): Full module name (prefix)\r\n\r\n    Returns:\r\n        torch.nn.Module: Resulting module\r\n\r\n    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\r\n    \"\"\"\r\n    res = module\r\n    is_match = True\r\n    if module_match:\r\n        is_match = name in module_match\r\n    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):\r\n        res = FrozenBatchNorm2d(module.num_features)\r\n        res.num_features = module.num_features\r\n        res.affine = module.affine\r\n        if module.affine:\r\n            res.weight.data = module.weight.data.clone().detach()\r\n            res.bias.data = module.bias.data.clone().detach()\r\n        res.running_mean.data = module.running_mean.data\r\n        res.running_var.data = module.running_var.data\r\n        res.eps = module.eps\r\n    else:\r\n        for child_name, child in module.named_children():\r\n            full_child_name = '.'.join([name, child_name]) if name else child_name\r\n            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\r\n            if new_child is not child:\r\n                res.add_module(child_name, new_child)\r\n    return res\r\n\r\n\r\n# From PyTorch internals\r\ndef _ntuple(n):\r\n    def parse(x):\r\n        if isinstance(x, collections.abc.Iterable):\r\n            return x\r\n        return tuple(repeat(x, n))\r\n    return parse\r\n\r\n\r\nto_1tuple = _ntuple(1)\r\nto_2tuple = _ntuple(2)\r\nto_3tuple = _ntuple(3)\r\nto_4tuple = _ntuple(4)\r\nto_ntuple = lambda n, x: _ntuple(n)(x)\r\n\r\n\r\ndef is_logging(args):\r\n    def is_global_master(args):\r\n        return args.rank == 0\r\n\r\n    def is_local_master(args):\r\n        return args.local_rank == 0\r\n\r\n    def is_master(args, local=False):\r\n        return is_local_master(args) if local else is_global_master(args)\r\n    return is_master\r\n\r\n\r\nclass AllGather(torch.autograd.Function):\r\n    \"\"\"An autograd function that performs allgather on a tensor.\r\n    Performs all_gather operation on the provided tensors.\r\n    *** Warning ***: torch.distributed.all_gather has no gradient.\r\n    \"\"\"\r\n\r\n    @staticmethod\r\n    def forward(ctx, tensor, rank, world_size):\r\n        tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]\r\n        torch.distributed.all_gather(tensors_gather, tensor)\r\n        ctx.rank = rank\r\n        ctx.batch_size = tensor.shape[0]\r\n        return torch.cat(tensors_gather, 0)\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        return (\r\n            grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],\r\n            None,\r\n            None\r\n        )\r\n\r\nallgather = AllGather.apply"
  },
  {
    "path": "examples/flux_pulid_multi.json",
    "content": "{\n  \"last_node_id\": 66,\n  \"last_link_id\": 133,\n  \"nodes\": [\n    {\n      \"id\": 16,\n      \"type\": \"KSamplerSelect\",\n      \"pos\": {\n        \"0\": 384,\n        \"1\": 313\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 58\n      },\n      \"flags\": {},\n      \"order\": 0,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"SAMPLER\",\n          \"type\": \"SAMPLER\",\n          \"links\": [\n            85\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"KSamplerSelect\"\n      },\n      \"widgets_values\": [\n        \"euler\"\n      ]\n    },\n    {\n      \"id\": 10,\n      \"type\": \"VAELoader\",\n      \"pos\": {\n        \"0\": 12,\n        \"1\": 285\n      },\n      \"size\": {\n        \"0\": 311.81634521484375,\n        \"1\": 60.429901123046875\n      },\n      \"flags\": {},\n      \"order\": 1,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"VAE\",\n          \"type\": \"VAE\",\n          \"links\": [\n            88\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"VAELoader\"\n      },\n      \"widgets_values\": [\n        \"ae.sft\"\n      ]\n    },\n    {\n      \"id\": 27,\n      \"type\": \"EmptySD3LatentImage\",\n      \"pos\": {\n        \"0\": 383,\n        \"1\": 155\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 106\n      },\n      \"flags\": {},\n      \"order\": 2,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"LATENT\",\n          \"type\": \"LATENT\",\n          \"links\": [\n            86\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"EmptySD3LatentImage\"\n      },\n      \"widgets_values\": [\n        896,\n        1152,\n        1\n      ]\n    },\n    {\n      \"id\": 25,\n      \"type\": \"RandomNoise\",\n      \"pos\": {\n        \"0\": 6,\n        \"1\": -135\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 82\n      },\n      \"flags\": {},\n      \"order\": 3,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"NOISE\",\n          \"type\": \"NOISE\",\n          \"links\": [\n            84\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"RandomNoise\"\n      },\n      \"widgets_values\": [\n        641817409332707,\n        \"randomize\"\n      ]\n    },\n    {\n      \"id\": 47,\n      \"type\": \"BasicGuider\",\n      \"pos\": {\n        \"0\": 1088,\n        \"1\": 366\n      },\n      \"size\": {\n        \"0\": 241.79998779296875,\n        \"1\": 46\n      },\n      \"flags\": {},\n      \"order\": 15,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"model\",\n          \"type\": \"MODEL\",\n          \"link\": 122\n        },\n        {\n          \"name\": \"conditioning\",\n          \"type\": \"CONDITIONING\",\n          \"link\": 107\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"GUIDER\",\n          \"type\": \"GUIDER\",\n          \"links\": [\n            83\n          ],\n          \"slot_index\": 0,\n      
    \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"BasicGuider\"\n      }\n    },\n    {\n      \"id\": 49,\n      \"type\": \"VAEDecode\",\n      \"pos\": {\n        \"0\": 1168,\n        \"1\": -111\n      },\n      \"size\": {\n        \"0\": 210,\n        \"1\": 46\n      },\n      \"flags\": {},\n      \"order\": 17,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"samples\",\n          \"type\": \"LATENT\",\n          \"link\": 87\n        },\n        {\n          \"name\": \"vae\",\n          \"type\": \"VAE\",\n          \"link\": 88\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"IMAGE\",\n          \"type\": \"IMAGE\",\n          \"links\": [\n            89\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"VAEDecode\"\n      }\n    },\n    {\n      \"id\": 50,\n      \"type\": \"PreviewImage\",\n      \"pos\": {\n        \"0\": 1502,\n        \"1\": -451\n      },\n      \"size\": {\n        \"0\": 1079.977783203125,\n        \"1\": 1041.9154052734375\n      },\n      \"flags\": {},\n      \"order\": 18,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"images\",\n          \"type\": \"IMAGE\",\n          \"link\": 89\n        }\n      ],\n      \"outputs\": [],\n      \"properties\": {\n        \"Node name for S&R\": \"PreviewImage\"\n      }\n    },\n    {\n      \"id\": 63,\n      \"type\": \"UNETLoader\",\n      \"pos\": {\n        \"0\": 6,\n        \"1\": -7\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 82\n      },\n      \"flags\": {},\n      \"order\": 4,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"MODEL\",\n          \"type\": \"MODEL\",\n          \"links\": [\n            130,\n            131\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"UNETLoader\"\n      },\n      \"widgets_values\": [\n        \"flux1-dev.safetensors\",\n        \"default\"\n      ]\n    },\n    {\n      \"id\": 17,\n      \"type\": \"BasicScheduler\",\n      \"pos\": {\n        \"0\": 392,\n        \"1\": 424\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 106\n      },\n      \"flags\": {\n        \"collapsed\": false\n      },\n      \"order\": 11,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"model\",\n          \"type\": \"MODEL\",\n          \"link\": 131,\n          \"slot_index\": 0\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"SIGMAS\",\n          \"type\": \"SIGMAS\",\n          \"links\": [\n            93\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"BasicScheduler\"\n      },\n      \"widgets_values\": [\n        \"simple\",\n        20,\n        1\n      ]\n    },\n    {\n      \"id\": 6,\n      \"type\": \"CLIPTextEncode\",\n      \"pos\": {\n        \"0\": 369,\n        \"1\": -63\n      },\n      \"size\": {\n        \"0\": 422.84503173828125,\n        \"1\": 164.31304931640625\n      },\n      \"flags\": {},\n      \"order\": 12,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"clip\",\n          \"type\": \"CLIP\",\n          \"link\": 132\n        }\n      ],\n      \"outputs\": 
[\n        {\n          \"name\": \"CONDITIONING\",\n          \"type\": \"CONDITIONING\",\n          \"links\": [\n            41\n          ],\n          \"slot_index\": 0\n        }\n      ],\n      \"title\": \"CLIP Text Encode (Positive Prompt)\",\n      \"properties\": {\n        \"Node name for S&R\": \"CLIPTextEncode\"\n      },\n      \"widgets_values\": [\n        \"portrait, color, cinematic\"\n      ]\n    },\n    {\n      \"id\": 48,\n      \"type\": \"SamplerCustomAdvanced\",\n      \"pos\": {\n        \"0\": 1128,\n        \"1\": -12\n      },\n      \"size\": {\n        \"0\": 355.20001220703125,\n        \"1\": 326\n      },\n      \"flags\": {},\n      \"order\": 16,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"noise\",\n          \"type\": \"NOISE\",\n          \"link\": 84\n        },\n        {\n          \"name\": \"guider\",\n          \"type\": \"GUIDER\",\n          \"link\": 83\n        },\n        {\n          \"name\": \"sampler\",\n          \"type\": \"SAMPLER\",\n          \"link\": 85\n        },\n        {\n          \"name\": \"sigmas\",\n          \"type\": \"SIGMAS\",\n          \"link\": 93\n        },\n        {\n          \"name\": \"latent_image\",\n          \"type\": \"LATENT\",\n          \"link\": 86\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"output\",\n          \"type\": \"LATENT\",\n          \"links\": [\n            87\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        },\n        {\n          \"name\": \"denoised_output\",\n          \"type\": \"LATENT\",\n          \"links\": null,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"SamplerCustomAdvanced\"\n      }\n    },\n    {\n      \"id\": 53,\n      \"type\": \"PulidFluxInsightFaceLoader\",\n      \"pos\": {\n        \"0\": 799,\n        \"1\": -172\n      },\n      \"size\": {\n        \"0\": 365.4000244140625,\n        \"1\": 58\n      },\n      \"flags\": {},\n      \"order\": 5,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"FACEANALYSIS\",\n          \"type\": \"FACEANALYSIS\",\n          \"links\": [\n            124\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"PulidFluxInsightFaceLoader\"\n      },\n      \"widgets_values\": [\n        \"CPU\"\n      ]\n    },\n    {\n      \"id\": 26,\n      \"type\": \"FluxGuidance\",\n      \"pos\": {\n        \"0\": 372,\n        \"1\": -171\n      },\n      \"size\": {\n        \"0\": 317.4000244140625,\n        \"1\": 58\n      },\n      \"flags\": {\n        \"collapsed\": false\n      },\n      \"order\": 14,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"conditioning\",\n          \"type\": \"CONDITIONING\",\n          \"link\": 41\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"CONDITIONING\",\n          \"type\": \"CONDITIONING\",\n          \"links\": [\n            107\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"FluxGuidance\"\n      },\n      \"widgets_values\": [\n        4\n      ]\n    },\n    {\n      \"id\": 64,\n      \"type\": \"DualCLIPLoader\",\n      \"pos\": {\n        \"0\": 8,\n        \"1\": 124\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 106\n  
    },\n      \"flags\": {},\n      \"order\": 6,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"CLIP\",\n          \"type\": \"CLIP\",\n          \"links\": [\n            132\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"DualCLIPLoader\"\n      },\n      \"widgets_values\": [\n        \"t5xxl_fp8_e4m3fn.safetensors\",\n        \"clip_l.safetensors\",\n        \"flux\"\n      ]\n    },\n    {\n      \"id\": 66,\n      \"type\": \"LoadImagesFromDir //Inspire\",\n      \"pos\": {\n        \"0\": 14,\n        \"1\": 623\n      },\n      \"size\": {\n        \"0\": 567,\n        \"1\": 170\n      },\n      \"flags\": {},\n      \"order\": 7,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"IMAGE\",\n          \"type\": \"IMAGE\",\n          \"links\": [\n            133\n          ],\n          \"shape\": 3,\n          \"slot_index\": 0\n        },\n        {\n          \"name\": \"MASK\",\n          \"type\": \"MASK\",\n          \"links\": null,\n          \"shape\": 3\n        },\n        {\n          \"name\": \"INT\",\n          \"type\": \"INT\",\n          \"links\": null,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"LoadImagesFromDir //Inspire\"\n      },\n      \"widgets_values\": [\n        \"\",\n        0,\n        0,\n        false\n      ]\n    },\n    {\n      \"id\": 45,\n      \"type\": \"PulidFluxModelLoader\",\n      \"pos\": {\n        \"0\": 788,\n        \"1\": 42\n      },\n      \"size\": {\n        \"0\": 315,\n        \"1\": 58\n      },\n      \"flags\": {},\n      \"order\": 8,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"PULIDFLUX\",\n          \"type\": \"PULIDFLUX\",\n          \"links\": [\n            125\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"PulidFluxModelLoader\"\n      },\n      \"widgets_values\": [\n        \"pulid_flux_v0.9.0.safetensors\"\n      ]\n    },\n    {\n      \"id\": 51,\n      \"type\": \"PulidFluxEvaClipLoader\",\n      \"pos\": {\n        \"0\": 799,\n        \"1\": -60\n      },\n      \"size\": {\n        \"0\": 327.5999755859375,\n        \"1\": 26\n      },\n      \"flags\": {},\n      \"order\": 9,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"EVA_CLIP\",\n          \"type\": \"EVA_CLIP\",\n          \"links\": [\n            123\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"PulidFluxEvaClipLoader\"\n      }\n    },\n    {\n      \"id\": 65,\n      \"type\": \"Note\",\n      \"pos\": {\n        \"0\": 797,\n        \"1\": 565\n      },\n      \"size\": {\n        \"0\": 278.80340576171875,\n        \"1\": 167.5153045654297\n      },\n      \"flags\": {},\n      \"order\": 10,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [],\n      \"properties\": {},\n      \"widgets_values\": [\n        \"fusion_weight_max and min only works when choose auto_weight.\\n\\ntrain_step only works when choose train_weight\"\n      ]\n    },\n    {\n      \"id\": 62,\n      \"type\": \"ApplyPulidFlux\",\n      \"pos\": {\n        \"0\": 740,\n        \"1\": 174\n      
},\n      \"size\": {\n        \"0\": 315,\n        \"1\": 326\n      },\n      \"flags\": {},\n      \"order\": 13,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"model\",\n          \"type\": \"MODEL\",\n          \"link\": 130\n        },\n        {\n          \"name\": \"pulid_flux\",\n          \"type\": \"PULIDFLUX\",\n          \"link\": 125\n        },\n        {\n          \"name\": \"eva_clip\",\n          \"type\": \"EVA_CLIP\",\n          \"link\": 123\n        },\n        {\n          \"name\": \"face_analysis\",\n          \"type\": \"FACEANALYSIS\",\n          \"link\": 124\n        },\n        {\n          \"name\": \"image\",\n          \"type\": \"IMAGE\",\n          \"link\": 133\n        },\n        {\n          \"name\": \"attn_mask\",\n          \"type\": \"MASK\",\n          \"link\": null\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"MODEL\",\n          \"type\": \"MODEL\",\n          \"links\": [\n            122\n          ],\n          \"slot_index\": 0,\n          \"shape\": 3\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"ApplyPulidFlux\"\n      },\n      \"widgets_values\": [\n        1,\n        0,\n        1,\n        \"mean\",\n        1,\n        0,\n        1000,\n        true\n      ]\n    }\n  ],\n  \"links\": [\n    [\n      41,\n      6,\n      0,\n      26,\n      0,\n      \"CONDITIONING\"\n    ],\n    [\n      83,\n      47,\n      0,\n      48,\n      1,\n      \"GUIDER\"\n    ],\n    [\n      84,\n      25,\n      0,\n      48,\n      0,\n      \"NOISE\"\n    ],\n    [\n      85,\n      16,\n      0,\n      48,\n      2,\n      \"SAMPLER\"\n    ],\n    [\n      86,\n      27,\n      0,\n      48,\n      4,\n      \"LATENT\"\n    ],\n    [\n      87,\n      48,\n      0,\n      49,\n      0,\n      \"LATENT\"\n    ],\n    [\n      88,\n      10,\n      0,\n      49,\n      1,\n      \"VAE\"\n    ],\n    [\n      89,\n      49,\n      0,\n      50,\n      0,\n      \"IMAGE\"\n    ],\n    [\n      93,\n      17,\n      0,\n      48,\n      3,\n      \"SIGMAS\"\n    ],\n    [\n      107,\n      26,\n      0,\n      47,\n      1,\n      \"CONDITIONING\"\n    ],\n    [\n      122,\n      62,\n      0,\n      47,\n      0,\n      \"MODEL\"\n    ],\n    [\n      123,\n      51,\n      0,\n      62,\n      2,\n      \"EVA_CLIP\"\n    ],\n    [\n      124,\n      53,\n      0,\n      62,\n      3,\n      \"FACEANALYSIS\"\n    ],\n    [\n      125,\n      45,\n      0,\n      62,\n      1,\n      \"PULIDFLUX\"\n    ],\n    [\n      130,\n      63,\n      0,\n      62,\n      0,\n      \"MODEL\"\n    ],\n    [\n      131,\n      63,\n      0,\n      17,\n      0,\n      \"MODEL\"\n    ],\n    [\n      132,\n      64,\n      0,\n      6,\n      0,\n      \"CLIP\"\n    ],\n    [\n      133,\n      66,\n      0,\n      62,\n      4,\n      \"IMAGE\"\n    ]\n  ],\n  \"groups\": [],\n  \"config\": {},\n  \"extra\": {\n    \"ds\": {\n      \"scale\": 0.6830134553650705,\n      \"offset\": [\n        237.9025120377926,\n        565.1585643260208\n      ]\n    }\n  },\n  \"version\": 0.4\n}\n"
  },
  {
    "path": "examples/pulid_flux_16bit_simple.json",
    "content": "{\r\n  \"last_node_id\": 64,\r\n  \"last_link_id\": 132,\r\n  \"nodes\": [\r\n    {\r\n      \"id\": 25,\r\n      \"type\": \"RandomNoise\",\r\n      \"pos\": {\r\n        \"0\": 6,\r\n        \"1\": -135\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 82\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 0,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"NOISE\",\r\n          \"type\": \"NOISE\",\r\n          \"links\": [\r\n            84\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"RandomNoise\"\r\n      },\r\n      \"widgets_values\": [\r\n        186462208016243,\r\n        \"fixed\"\r\n      ],\r\n      \"color\": \"#2a363b\",\r\n      \"bgcolor\": \"#3f5159\"\r\n    },\r\n    {\r\n      \"id\": 26,\r\n      \"type\": \"FluxGuidance\",\r\n      \"pos\": {\r\n        \"0\": 372,\r\n        \"1\": -171\r\n      },\r\n      \"size\": {\r\n        \"0\": 317.4000244140625,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {\r\n        \"collapsed\": false\r\n      },\r\n      \"order\": 13,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"conditioning\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"link\": 41\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CONDITIONING\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"links\": [\r\n            107\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"FluxGuidance\"\r\n      },\r\n      \"widgets_values\": [\r\n        3.5\r\n      ],\r\n      \"color\": \"#233\",\r\n      \"bgcolor\": \"#355\"\r\n    },\r\n    {\r\n      \"id\": 6,\r\n      \"type\": \"CLIPTextEncode\",\r\n      \"pos\": {\r\n        \"0\": 372,\r\n        \"1\": -55\r\n      },\r\n      \"size\": {\r\n        \"0\": 422.84503173828125,\r\n        \"1\": 164.31304931640625\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 12,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"clip\",\r\n          \"type\": \"CLIP\",\r\n          \"link\": 132\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CONDITIONING\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"links\": [\r\n            41\r\n          ],\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"title\": \"CLIP Text Encode (Positive Prompt)\",\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"CLIPTextEncode\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\\\"PuLID Flux\\\", detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard.\"\r\n      ],\r\n      \"color\": \"#232\",\r\n      \"bgcolor\": \"#353\"\r\n    },\r\n    {\r\n      \"id\": 27,\r\n      \"type\": \"EmptySD3LatentImage\",\r\n      \"pos\": {\r\n        \"0\": 383,\r\n        \"1\": 155\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {},\r\n      
\"order\": 1,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"LATENT\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": [\r\n            86\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"EmptySD3LatentImage\"\r\n      },\r\n      \"widgets_values\": [\r\n        768,\r\n        1024,\r\n        1\r\n      ],\r\n      \"color\": \"#323\",\r\n      \"bgcolor\": \"#535\"\r\n    },\r\n    {\r\n      \"id\": 16,\r\n      \"type\": \"KSamplerSelect\",\r\n      \"pos\": {\r\n        \"0\": 384,\r\n        \"1\": 313\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 2,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"SAMPLER\",\r\n          \"type\": \"SAMPLER\",\r\n          \"links\": [\r\n            85\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"KSamplerSelect\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"euler\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 17,\r\n      \"type\": \"BasicScheduler\",\r\n      \"pos\": {\r\n        \"0\": 392,\r\n        \"1\": 424\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {\r\n        \"collapsed\": false\r\n      },\r\n      \"order\": 11,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 131,\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"SIGMAS\",\r\n          \"type\": \"SIGMAS\",\r\n          \"links\": [\r\n            93\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"BasicScheduler\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"simple\",\r\n        10,\r\n        1\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 54,\r\n      \"type\": \"LoadImage\",\r\n      \"pos\": {\r\n        \"0\": 729,\r\n        \"1\": -490\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 314\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 3,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"IMAGE\",\r\n          \"type\": \"IMAGE\",\r\n          \"links\": [\r\n            126\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        },\r\n        {\r\n          \"name\": \"MASK\",\r\n          \"type\": \"MASK\",\r\n          \"links\": null,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"LoadImage\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"einstein.jpg\",\r\n        \"image\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 53,\r\n      \"type\": \"PulidFluxInsightFaceLoader\",\r\n      \"pos\": {\r\n        \"0\": 822,\r\n        \"1\": -80\r\n      },\r\n      \"size\": {\r\n        \"0\": 365.4000244140625,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 4,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        
{\r\n          \"name\": \"FACEANALYSIS\",\r\n          \"type\": \"FACEANALYSIS\",\r\n          \"links\": [\r\n            124\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxInsightFaceLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"CPU\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 51,\r\n      \"type\": \"PulidFluxEvaClipLoader\",\r\n      \"pos\": {\r\n        \"0\": 845,\r\n        \"1\": 52\r\n      },\r\n      \"size\": {\r\n        \"0\": 327.5999755859375,\r\n        \"1\": 26\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 5,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"EVA_CLIP\",\r\n          \"type\": \"EVA_CLIP\",\r\n          \"links\": [\r\n            123\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxEvaClipLoader\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 45,\r\n      \"type\": \"PulidFluxModelLoader\",\r\n      \"pos\": {\r\n        \"0\": 846,\r\n        \"1\": 137\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 6,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"PULIDFLUX\",\r\n          \"type\": \"PULIDFLUX\",\r\n          \"links\": [\r\n            125\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxModelLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"pulid_flux_v0.9.0.safetensors\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 62,\r\n      \"type\": \"ApplyPulidFlux\",\r\n      \"pos\": {\r\n        \"0\": 842,\r\n        \"1\": 258\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 206\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 10,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 130\r\n        },\r\n        {\r\n          \"name\": \"pulid_flux\",\r\n          \"type\": \"PULIDFLUX\",\r\n          \"link\": 125\r\n        },\r\n        {\r\n          \"name\": \"eva_clip\",\r\n          \"type\": \"EVA_CLIP\",\r\n          \"link\": 123\r\n        },\r\n        {\r\n          \"name\": \"face_analysis\",\r\n          \"type\": \"FACEANALYSIS\",\r\n          \"link\": 124\r\n        },\r\n        {\r\n          \"name\": \"image\",\r\n          \"type\": \"IMAGE\",\r\n          \"link\": 126\r\n        },\r\n        {\r\n          \"name\": \"attn_mask\",\r\n          \"type\": \"MASK\",\r\n          \"link\": null\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"MODEL\",\r\n          \"type\": \"MODEL\",\r\n          \"links\": [\r\n            122\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"ApplyPulidFlux\"\r\n      },\r\n      \"widgets_values\": [\r\n        1,\r\n        0,\r\n        1\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 47,\r\n      \"type\": \"BasicGuider\",\r\n      \"pos\": {\r\n        \"0\": 1217,\r\n        \"1\": 401\r\n      
},\r\n      \"size\": {\r\n        \"0\": 241.79998779296875,\r\n        \"1\": 46\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 14,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 122\r\n        },\r\n        {\r\n          \"name\": \"conditioning\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"link\": 107\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"GUIDER\",\r\n          \"type\": \"GUIDER\",\r\n          \"links\": [\r\n            83\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"BasicGuider\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 48,\r\n      \"type\": \"SamplerCustomAdvanced\",\r\n      \"pos\": {\r\n        \"0\": 1205,\r\n        \"1\": -39\r\n      },\r\n      \"size\": {\r\n        \"0\": 355.20001220703125,\r\n        \"1\": 326\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 15,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"noise\",\r\n          \"type\": \"NOISE\",\r\n          \"link\": 84\r\n        },\r\n        {\r\n          \"name\": \"guider\",\r\n          \"type\": \"GUIDER\",\r\n          \"link\": 83\r\n        },\r\n        {\r\n          \"name\": \"sampler\",\r\n          \"type\": \"SAMPLER\",\r\n          \"link\": 85\r\n        },\r\n        {\r\n          \"name\": \"sigmas\",\r\n          \"type\": \"SIGMAS\",\r\n          \"link\": 93\r\n        },\r\n        {\r\n          \"name\": \"latent_image\",\r\n          \"type\": \"LATENT\",\r\n          \"link\": 86\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"output\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": [\r\n            87\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        },\r\n        {\r\n          \"name\": \"denoised_output\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": null,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"SamplerCustomAdvanced\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 49,\r\n      \"type\": \"VAEDecode\",\r\n      \"pos\": {\r\n        \"0\": 1263,\r\n        \"1\": -137\r\n      },\r\n      \"size\": {\r\n        \"0\": 210,\r\n        \"1\": 46\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 16,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"samples\",\r\n          \"type\": \"LATENT\",\r\n          \"link\": 87\r\n        },\r\n        {\r\n          \"name\": \"vae\",\r\n          \"type\": \"VAE\",\r\n          \"link\": 88\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"IMAGE\",\r\n          \"type\": \"IMAGE\",\r\n          \"links\": [\r\n            89\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"VAEDecode\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 50,\r\n      \"type\": \"PreviewImage\",\r\n      \"pos\": {\r\n        \"0\": 1587,\r\n        \"1\": -169\r\n      },\r\n      \"size\": {\r\n        \"0\": 841.524169921875,\r\n        \"1\": 698.3060302734375\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 17,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n  
      {\r\n          \"name\": \"images\",\r\n          \"type\": \"IMAGE\",\r\n          \"link\": 89\r\n        }\r\n      ],\r\n      \"outputs\": [],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PreviewImage\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 63,\r\n      \"type\": \"UNETLoader\",\r\n      \"pos\": {\r\n        \"0\": 6,\r\n        \"1\": -7\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 82\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 7,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"MODEL\",\r\n          \"type\": \"MODEL\",\r\n          \"links\": [\r\n            130,\r\n            131\r\n          ],\r\n          \"shape\": 3,\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"UNETLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"flux1-dev.safetensors\",\r\n        \"default\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 10,\r\n      \"type\": \"VAELoader\",\r\n      \"pos\": {\r\n        \"0\": 12,\r\n        \"1\": 285\r\n      },\r\n      \"size\": {\r\n        \"0\": 311.81634521484375,\r\n        \"1\": 60.429901123046875\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 8,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"VAE\",\r\n          \"type\": \"VAE\",\r\n          \"links\": [\r\n            88\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"VAELoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"flux1_vae.safetensors\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 64,\r\n      \"type\": \"DualCLIPLoader\",\r\n      \"pos\": {\r\n        \"0\": 8,\r\n        \"1\": 124\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 9,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CLIP\",\r\n          \"type\": \"CLIP\",\r\n          \"links\": [\r\n            132\r\n          ],\r\n          \"shape\": 3,\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"DualCLIPLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"t5xxl_fp16.safetensors\",\r\n        \"clip_l.safetensors\",\r\n        \"flux\"\r\n      ]\r\n    }\r\n  ],\r\n  \"links\": [\r\n    [\r\n      41,\r\n      6,\r\n      0,\r\n      26,\r\n      0,\r\n      \"CONDITIONING\"\r\n    ],\r\n    [\r\n      83,\r\n      47,\r\n      0,\r\n      48,\r\n      1,\r\n      \"GUIDER\"\r\n    ],\r\n    [\r\n      84,\r\n      25,\r\n      0,\r\n      48,\r\n      0,\r\n      \"NOISE\"\r\n    ],\r\n    [\r\n      85,\r\n      16,\r\n      0,\r\n      48,\r\n      2,\r\n      \"SAMPLER\"\r\n    ],\r\n    [\r\n      86,\r\n      27,\r\n      0,\r\n      48,\r\n      4,\r\n      \"LATENT\"\r\n    ],\r\n    [\r\n      87,\r\n      48,\r\n      0,\r\n      49,\r\n      0,\r\n      \"LATENT\"\r\n    ],\r\n    [\r\n      88,\r\n      10,\r\n      0,\r\n      49,\r\n      1,\r\n      \"VAE\"\r\n    ],\r\n    [\r\n      89,\r\n      49,\r\n      0,\r\n      50,\r\n      0,\r\n      \"IMAGE\"\r\n    ],\r\n    [\r\n      93,\r\n      17,\r\n      0,\r\n      48,\r\n      3,\r\n      \"SIGMAS\"\r\n    ],\r\n    [\r\n   
   107,\r\n      26,\r\n      0,\r\n      47,\r\n      1,\r\n      \"CONDITIONING\"\r\n    ],\r\n    [\r\n      122,\r\n      62,\r\n      0,\r\n      47,\r\n      0,\r\n      \"MODEL\"\r\n    ],\r\n    [\r\n      123,\r\n      51,\r\n      0,\r\n      62,\r\n      2,\r\n      \"EVA_CLIP\"\r\n    ],\r\n    [\r\n      124,\r\n      53,\r\n      0,\r\n      62,\r\n      3,\r\n      \"FACEANALYSIS\"\r\n    ],\r\n    [\r\n      125,\r\n      45,\r\n      0,\r\n      62,\r\n      1,\r\n      \"PULIDFLUX\"\r\n    ],\r\n    [\r\n      126,\r\n      54,\r\n      0,\r\n      62,\r\n      4,\r\n      \"IMAGE\"\r\n    ],\r\n    [\r\n      130,\r\n      63,\r\n      0,\r\n      62,\r\n      0,\r\n      \"MODEL\"\r\n    ],\r\n    [\r\n      131,\r\n      63,\r\n      0,\r\n      17,\r\n      0,\r\n      \"MODEL\"\r\n    ],\r\n    [\r\n      132,\r\n      64,\r\n      0,\r\n      6,\r\n      0,\r\n      \"CLIP\"\r\n    ]\r\n  ],\r\n  \"groups\": [],\r\n  \"config\": {},\r\n  \"extra\": {\r\n    \"ds\": {\r\n      \"scale\": 0.9090909090909091,\r\n      \"offset\": [\r\n        113.84966682267732,\r\n        547.8597243753773\r\n      ]\r\n    }\r\n  },\r\n  \"version\": 0.4\r\n}"
  },
  {
    "path": "examples/pulid_flux_8bitgguf_simple.json",
    "content": "{\r\n  \"last_node_id\": 62,\r\n  \"last_link_id\": 129,\r\n  \"nodes\": [\r\n    {\r\n      \"id\": 25,\r\n      \"type\": \"RandomNoise\",\r\n      \"pos\": {\r\n        \"0\": 6,\r\n        \"1\": -135\r\n      },\r\n      \"size\": [\r\n        315,\r\n        82\r\n      ],\r\n      \"flags\": {},\r\n      \"order\": 0,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"NOISE\",\r\n          \"type\": \"NOISE\",\r\n          \"links\": [\r\n            84\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"RandomNoise\"\r\n      },\r\n      \"widgets_values\": [\r\n        186462208016243,\r\n        \"fixed\"\r\n      ],\r\n      \"color\": \"#2a363b\",\r\n      \"bgcolor\": \"#3f5159\"\r\n    },\r\n    {\r\n      \"id\": 31,\r\n      \"type\": \"UnetLoaderGGUF\",\r\n      \"pos\": {\r\n        \"0\": 14,\r\n        \"1\": 5\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 1,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"MODEL\",\r\n          \"type\": \"MODEL\",\r\n          \"links\": [\r\n            127,\r\n            129\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"UnetLoaderGGUF\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"flux1-dev-Q8_0.gguf\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 41,\r\n      \"type\": \"DualCLIPLoaderGGUF\",\r\n      \"pos\": {\r\n        \"0\": 18,\r\n        \"1\": 114\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 2,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CLIP\",\r\n          \"type\": \"CLIP\",\r\n          \"links\": [\r\n            128\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"DualCLIPLoaderGGUF\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"t5-v1_1-xxl-encoder-Q8_0.gguf\",\r\n        \"clip_l.safetensors\",\r\n        \"flux\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 10,\r\n      \"type\": \"VAELoader\",\r\n      \"pos\": {\r\n        \"0\": 23,\r\n        \"1\": 275\r\n      },\r\n      \"size\": {\r\n        \"0\": 311.81634521484375,\r\n        \"1\": 60.429901123046875\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 3,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"VAE\",\r\n          \"type\": \"VAE\",\r\n          \"links\": [\r\n            88\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"VAELoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"flux1_vae.safetensors\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 26,\r\n      \"type\": \"FluxGuidance\",\r\n      \"pos\": {\r\n        \"0\": 372,\r\n        \"1\": -171\r\n      },\r\n      \"size\": {\r\n        \"0\": 317.4000244140625,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {\r\n        \"collapsed\": false\r\n      },\r\n      
\"order\": 13,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"conditioning\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"link\": 41\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CONDITIONING\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"links\": [\r\n            107\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"FluxGuidance\"\r\n      },\r\n      \"widgets_values\": [\r\n        3.5\r\n      ],\r\n      \"color\": \"#233\",\r\n      \"bgcolor\": \"#355\"\r\n    },\r\n    {\r\n      \"id\": 6,\r\n      \"type\": \"CLIPTextEncode\",\r\n      \"pos\": {\r\n        \"0\": 372,\r\n        \"1\": -55\r\n      },\r\n      \"size\": {\r\n        \"0\": 422.84503173828125,\r\n        \"1\": 164.31304931640625\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 11,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"clip\",\r\n          \"type\": \"CLIP\",\r\n          \"link\": 128\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"CONDITIONING\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"links\": [\r\n            41\r\n          ],\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"title\": \"CLIP Text Encode (Positive Prompt)\",\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"CLIPTextEncode\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\\\"PuLID Flux\\\", detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard.\"\r\n      ],\r\n      \"color\": \"#232\",\r\n      \"bgcolor\": \"#353\"\r\n    },\r\n    {\r\n      \"id\": 27,\r\n      \"type\": \"EmptySD3LatentImage\",\r\n      \"pos\": {\r\n        \"0\": 383,\r\n        \"1\": 155\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 4,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"LATENT\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": [\r\n            86\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"EmptySD3LatentImage\"\r\n      },\r\n      \"widgets_values\": [\r\n        768,\r\n        1024,\r\n        1\r\n      ],\r\n      \"color\": \"#323\",\r\n      \"bgcolor\": \"#535\"\r\n    },\r\n    {\r\n      \"id\": 16,\r\n      \"type\": \"KSamplerSelect\",\r\n      \"pos\": {\r\n        \"0\": 384,\r\n        \"1\": 313\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 5,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"SAMPLER\",\r\n          \"type\": \"SAMPLER\",\r\n          \"links\": [\r\n            85\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"KSamplerSelect\"\r\n      
},\r\n      \"widgets_values\": [\r\n        \"euler\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 17,\r\n      \"type\": \"BasicScheduler\",\r\n      \"pos\": {\r\n        \"0\": 392,\r\n        \"1\": 424\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 106\r\n      },\r\n      \"flags\": {\r\n        \"collapsed\": false\r\n      },\r\n      \"order\": 10,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 129,\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"SIGMAS\",\r\n          \"type\": \"SIGMAS\",\r\n          \"links\": [\r\n            93\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"BasicScheduler\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"simple\",\r\n        10,\r\n        1\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 54,\r\n      \"type\": \"LoadImage\",\r\n      \"pos\": {\r\n        \"0\": 729,\r\n        \"1\": -490\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 314\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 6,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"IMAGE\",\r\n          \"type\": \"IMAGE\",\r\n          \"links\": [\r\n            126\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        },\r\n        {\r\n          \"name\": \"MASK\",\r\n          \"type\": \"MASK\",\r\n          \"links\": null,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"LoadImage\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"einstein.jpg\",\r\n        \"image\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 53,\r\n      \"type\": \"PulidFluxInsightFaceLoader\",\r\n      \"pos\": {\r\n        \"0\": 822,\r\n        \"1\": -80\r\n      },\r\n      \"size\": {\r\n        \"0\": 365.4000244140625,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 7,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"FACEANALYSIS\",\r\n          \"type\": \"FACEANALYSIS\",\r\n          \"links\": [\r\n            124\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxInsightFaceLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"CPU\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 51,\r\n      \"type\": \"PulidFluxEvaClipLoader\",\r\n      \"pos\": {\r\n        \"0\": 845,\r\n        \"1\": 52\r\n      },\r\n      \"size\": {\r\n        \"0\": 327.5999755859375,\r\n        \"1\": 26\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 8,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"EVA_CLIP\",\r\n          \"type\": \"EVA_CLIP\",\r\n          \"links\": [\r\n            123\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxEvaClipLoader\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 45,\r\n      \"type\": \"PulidFluxModelLoader\",\r\n      \"pos\": {\r\n        \"0\": 846,\r\n     
   \"1\": 137\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 58\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 9,\r\n      \"mode\": 0,\r\n      \"inputs\": [],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"PULIDFLUX\",\r\n          \"type\": \"PULIDFLUX\",\r\n          \"links\": [\r\n            125\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PulidFluxModelLoader\"\r\n      },\r\n      \"widgets_values\": [\r\n        \"pulid_flux_v0.9.0.safetensors\"\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 62,\r\n      \"type\": \"ApplyPulidFlux\",\r\n      \"pos\": {\r\n        \"0\": 842,\r\n        \"1\": 258\r\n      },\r\n      \"size\": {\r\n        \"0\": 315,\r\n        \"1\": 206\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 12,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 127\r\n        },\r\n        {\r\n          \"name\": \"pulid_flux\",\r\n          \"type\": \"PULIDFLUX\",\r\n          \"link\": 125\r\n        },\r\n        {\r\n          \"name\": \"eva_clip\",\r\n          \"type\": \"EVA_CLIP\",\r\n          \"link\": 123\r\n        },\r\n        {\r\n          \"name\": \"face_analysis\",\r\n          \"type\": \"FACEANALYSIS\",\r\n          \"link\": 124\r\n        },\r\n        {\r\n          \"name\": \"image\",\r\n          \"type\": \"IMAGE\",\r\n          \"link\": 126\r\n        },\r\n        {\r\n          \"name\": \"attn_mask\",\r\n          \"type\": \"MASK\",\r\n          \"link\": null\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"MODEL\",\r\n          \"type\": \"MODEL\",\r\n          \"links\": [\r\n            122\r\n          ],\r\n          \"shape\": 3,\r\n          \"slot_index\": 0\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"ApplyPulidFlux\"\r\n      },\r\n      \"widgets_values\": [\r\n        1,\r\n        0,\r\n        1\r\n      ]\r\n    },\r\n    {\r\n      \"id\": 47,\r\n      \"type\": \"BasicGuider\",\r\n      \"pos\": {\r\n        \"0\": 1217,\r\n        \"1\": 401\r\n      },\r\n      \"size\": {\r\n        \"0\": 241.79998779296875,\r\n        \"1\": 46\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 14,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"model\",\r\n          \"type\": \"MODEL\",\r\n          \"link\": 122\r\n        },\r\n        {\r\n          \"name\": \"conditioning\",\r\n          \"type\": \"CONDITIONING\",\r\n          \"link\": 107\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"GUIDER\",\r\n          \"type\": \"GUIDER\",\r\n          \"links\": [\r\n            83\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"BasicGuider\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 48,\r\n      \"type\": \"SamplerCustomAdvanced\",\r\n      \"pos\": {\r\n        \"0\": 1205,\r\n        \"1\": -39\r\n      },\r\n      \"size\": {\r\n        \"0\": 355.20001220703125,\r\n        \"1\": 326\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 15,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"noise\",\r\n          \"type\": \"NOISE\",\r\n        
  \"link\": 84\r\n        },\r\n        {\r\n          \"name\": \"guider\",\r\n          \"type\": \"GUIDER\",\r\n          \"link\": 83\r\n        },\r\n        {\r\n          \"name\": \"sampler\",\r\n          \"type\": \"SAMPLER\",\r\n          \"link\": 85\r\n        },\r\n        {\r\n          \"name\": \"sigmas\",\r\n          \"type\": \"SIGMAS\",\r\n          \"link\": 93\r\n        },\r\n        {\r\n          \"name\": \"latent_image\",\r\n          \"type\": \"LATENT\",\r\n          \"link\": 86\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"output\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": [\r\n            87\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        },\r\n        {\r\n          \"name\": \"denoised_output\",\r\n          \"type\": \"LATENT\",\r\n          \"links\": null,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"SamplerCustomAdvanced\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 49,\r\n      \"type\": \"VAEDecode\",\r\n      \"pos\": {\r\n        \"0\": 1263,\r\n        \"1\": -137\r\n      },\r\n      \"size\": {\r\n        \"0\": 210,\r\n        \"1\": 46\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 16,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"samples\",\r\n          \"type\": \"LATENT\",\r\n          \"link\": 87\r\n        },\r\n        {\r\n          \"name\": \"vae\",\r\n          \"type\": \"VAE\",\r\n          \"link\": 88\r\n        }\r\n      ],\r\n      \"outputs\": [\r\n        {\r\n          \"name\": \"IMAGE\",\r\n          \"type\": \"IMAGE\",\r\n          \"links\": [\r\n            89\r\n          ],\r\n          \"slot_index\": 0,\r\n          \"shape\": 3\r\n        }\r\n      ],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"VAEDecode\"\r\n      }\r\n    },\r\n    {\r\n      \"id\": 50,\r\n      \"type\": \"PreviewImage\",\r\n      \"pos\": {\r\n        \"0\": 1587,\r\n        \"1\": -169\r\n      },\r\n      \"size\": {\r\n        \"0\": 841.524169921875,\r\n        \"1\": 698.3060302734375\r\n      },\r\n      \"flags\": {},\r\n      \"order\": 17,\r\n      \"mode\": 0,\r\n      \"inputs\": [\r\n        {\r\n          \"name\": \"images\",\r\n          \"type\": \"IMAGE\",\r\n          \"link\": 89\r\n        }\r\n      ],\r\n      \"outputs\": [],\r\n      \"properties\": {\r\n        \"Node name for S&R\": \"PreviewImage\"\r\n      }\r\n    }\r\n  ],\r\n  \"links\": [\r\n    [\r\n      41,\r\n      6,\r\n      0,\r\n      26,\r\n      0,\r\n      \"CONDITIONING\"\r\n    ],\r\n    [\r\n      83,\r\n      47,\r\n      0,\r\n      48,\r\n      1,\r\n      \"GUIDER\"\r\n    ],\r\n    [\r\n      84,\r\n      25,\r\n      0,\r\n      48,\r\n      0,\r\n      \"NOISE\"\r\n    ],\r\n    [\r\n      85,\r\n      16,\r\n      0,\r\n      48,\r\n      2,\r\n      \"SAMPLER\"\r\n    ],\r\n    [\r\n      86,\r\n      27,\r\n      0,\r\n      48,\r\n      4,\r\n      \"LATENT\"\r\n    ],\r\n    [\r\n      87,\r\n      48,\r\n      0,\r\n      49,\r\n      0,\r\n      \"LATENT\"\r\n    ],\r\n    [\r\n      88,\r\n      10,\r\n      0,\r\n      49,\r\n      1,\r\n      \"VAE\"\r\n    ],\r\n    [\r\n      89,\r\n      49,\r\n      0,\r\n      50,\r\n      0,\r\n      \"IMAGE\"\r\n    ],\r\n    [\r\n      93,\r\n      17,\r\n      0,\r\n      48,\r\n      3,\r\n      \"SIGMAS\"\r\n    ],\r\n    [\r\n      107,\r\n      26,\r\n 
     0,\r\n      47,\r\n      1,\r\n      \"CONDITIONING\"\r\n    ],\r\n    [\r\n      122,\r\n      62,\r\n      0,\r\n      47,\r\n      0,\r\n      \"MODEL\"\r\n    ],\r\n    [\r\n      123,\r\n      51,\r\n      0,\r\n      62,\r\n      2,\r\n      \"EVA_CLIP\"\r\n    ],\r\n    [\r\n      124,\r\n      53,\r\n      0,\r\n      62,\r\n      3,\r\n      \"FACEANALYSIS\"\r\n    ],\r\n    [\r\n      125,\r\n      45,\r\n      0,\r\n      62,\r\n      1,\r\n      \"PULIDFLUX\"\r\n    ],\r\n    [\r\n      126,\r\n      54,\r\n      0,\r\n      62,\r\n      4,\r\n      \"IMAGE\"\r\n    ],\r\n    [\r\n      127,\r\n      31,\r\n      0,\r\n      62,\r\n      0,\r\n      \"MODEL\"\r\n    ],\r\n    [\r\n      128,\r\n      41,\r\n      0,\r\n      6,\r\n      0,\r\n      \"CLIP\"\r\n    ],\r\n    [\r\n      129,\r\n      31,\r\n      0,\r\n      17,\r\n      0,\r\n      \"MODEL\"\r\n    ]\r\n  ],\r\n  \"groups\": [],\r\n  \"config\": {},\r\n  \"extra\": {\r\n    \"ds\": {\r\n      \"scale\": 0.7513148009015777,\r\n      \"offset\": [\r\n        124.42912136813258,\r\n        743.5079061935592\r\n      ]\r\n    }\r\n  },\r\n  \"version\": 0.4\r\n}"
  },
  {
    "path": "online_train1.py",
    "content": "# supervised by a global average embedding, which is a biased estimation of the true embedding\r\n# use projection to enable a complex decoding\r\n# makes no big difference than mean so far, the decoding may not work 🤦‍\r\n\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\nimport torch\r\nfrom tqdm import tqdm\r\nimport random\r\n\r\nclass Transform(nn.Module):\r\n    def __init__(self, n=2, token_size=32, input_dim=2048):\r\n        super().__init__()\r\n        \r\n        self.n=n\r\n        self.dim= input_dim*token_size\r\n        self.token_size=token_size\r\n        self.input_dim=input_dim\r\n        \r\n        self.weight = nn.Parameter(torch.ones(self.n,1),requires_grad=True)\r\n        \r\n        self.projections = nn.ModuleList([nn.Sequential(\r\n            nn.Linear(self.dim, 512),\r\n            nn.ReLU(),\r\n            nn.Linear(512, self.dim)\r\n        ) for _ in range(self.n)])\r\n        \r\n    def encode(self, x):\r\n        x = x.view(-1, self.dim)\r\n        x = self.weight*x\r\n        return x\r\n    \r\n    def decode(self, x):\r\n        out=[]\r\n        for i in range(self.n):\r\n            t = self.projections[i](x[i])\r\n            out.append(t)\r\n        x = torch.stack(out, dim=0)\r\n        x=x.view(self.n,self.token_size,self.input_dim)\r\n        x=torch.mean(x,dim=0)\r\n        return x\r\n    \r\n    def forward(self, x):\r\n        x = self.encode(x)\r\n        x = self.decode(x)\r\n        return x\r\n\r\ndef online_train(cond, device=\"cuda:1\",step=1000):\r\n    old_device=cond.device\r\n    dtype=cond.dtype\r\n    cond = cond.clone().to(device,torch.float32)\r\n    cond.requires_grad=False\r\n    torch.set_grad_enabled(True)\r\n    \r\n    print(\"online training, initializing model...\")\r\n    n=cond.shape[0]\r\n    model=Transform(n=n)\r\n    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\r\n    criterion = nn.MSELoss()\r\n    model.to(device)\r\n    model.train()\r\n        \r\n    y=torch.mean(cond,dim=0)\r\n    \r\n    random.seed(42)\r\n    bar=tqdm(range(step))\r\n    for s in bar:\r\n        optimizer.zero_grad()\r\n        attack_weight=[random.uniform(0.5,1.5) for _ in range(n)]\r\n        attack_weight=torch.tensor(attack_weight)[:,None,None].to(device)\r\n        x=attack_weight*cond\r\n        output = model(x)\r\n        loss = criterion(output, y)\r\n        loss.backward()\r\n        optimizer.step()\r\n        bar.set_postfix(loss=loss.item())\r\n        \r\n    weight=model.weight\r\n    cond=weight[:,:,None]*cond\r\n    print(weight)\r\n    \r\n    print(\"online training, ending...\")\r\n    del model\r\n    del optimizer\r\n    \r\n    cond=torch.mean(cond,dim=0).unsqueeze(0)\r\n    return cond.to(old_device,dtype=dtype)"
  },
  {
    "path": "online_train2.py",
    "content": "# self-supervised learning, one of the embedding acts as the target, the other as the support\r\n# works nicely\r\n\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch.optim as optim\r\nimport torch\r\nfrom tqdm import tqdm\r\nimport random\r\n\r\nclass Transform(nn.Module):\r\n    def __init__(self, n=2, token_size=32, input_dim=2048):\r\n        super().__init__()\r\n        \r\n        self.n=n\r\n        self.token_size=token_size\r\n        \r\n        self.weight = nn.Parameter(torch.ones(self.n,self.token_size),requires_grad=True)\r\n        \r\n    def encode(self, x):\r\n        x = torch.einsum('bij,bi->ij', x, self.weight)\r\n        return x\r\n    \r\n    def forward(self, x):\r\n        x = self.encode(x)\r\n        return x\r\n    \r\ndef criterion(output, target, token_sample_rate=0.25):\r\n    t=target-output\r\n    t=torch.norm(t,dim=1)\r\n    s=random.sample(range(t.shape[0]),int(token_sample_rate*t.shape[0]))\r\n    return torch.mean(t[s])\r\n\r\ndef online_train(cond, device=\"cuda:1\",step=1000):\r\n    old_device=cond.device\r\n    dtype=cond.dtype\r\n    cond = cond.clone().to(device,torch.float32)\r\n    # cond.requires_grad=False\r\n    # torch.set_grad_enabled(True)\r\n    \r\n    y=cond[0,:,:]\r\n    cond=cond[1:,:,:]\r\n    \r\n    print(\"online training, initializing model...\")\r\n    n=cond.shape[0]\r\n    model=Transform(n=n)\r\n    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\r\n    model.to(device)\r\n    model.train()\r\n    \r\n    random.seed(42)\r\n    bar=tqdm(range(step))\r\n    for s in bar:\r\n        optimizer.zero_grad()\r\n        x=cond\r\n        output = model(x)\r\n        loss = criterion(output, y)\r\n        loss.backward()\r\n        optimizer.step()\r\n        bar.set_postfix(loss=loss.item())\r\n        \r\n    weight=model.weight\r\n    print(weight)\r\n    cond=weight[:,:,None]*cond+y[None,:,:]*(1.0/n)\r\n    \r\n    print(\"online training, ending...\")\r\n    del model\r\n    del optimizer\r\n    \r\n    cond=torch.mean(cond,dim=0).unsqueeze(0)\r\n    return cond.to(old_device,dtype=dtype)"
  },
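  {
    "path": "examples/train_weight_fusion_demo.py",
    "content": "# Hypothetical usage sketch, NOT part of the original custom node set: it only\r\n# illustrates how the \"train_weight\" fusion path in pulidflux.py hands the\r\n# stacked PuLID embeddings to online_train2.online_train. The file name and the\r\n# random stand-in tensor are assumptions made for demonstration purposes.\r\n# Assumes it is run from the repository root so that online_train2 is importable.\r\n\r\nimport torch\r\n\r\nfrom online_train2 import online_train\r\n\r\nif __name__ == \"__main__\":\r\n    # Stand-in for pulid_flux.get_embeds() outputs gathered from several\r\n    # reference images: N embeddings of shape (32, 2048) stacked to (N, 32, 2048).\r\n    cond = torch.randn(3, 32, 2048)\r\n\r\n    # pulidflux.py calls online_train(cond, device=cond.device, step=train_step)\r\n    # when fusion == \"train_weight\"; a small step count and CPU keep the demo cheap.\r\n    fused = online_train(cond, device=\"cpu\", step=50)\r\n\r\n    print(fused.shape)  # expected: torch.Size([1, 32, 2048])\r\n"
  },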
  {
    "path": "pulidflux.py",
    "content": "\r\nimport torch\r\nfrom torch import nn, Tensor\r\nfrom torchvision import transforms\r\nfrom torchvision.transforms import functional\r\nimport os\r\nimport logging\r\nimport folder_paths\r\nimport comfy.utils\r\nfrom comfy.ldm.flux.layers import timestep_embedding\r\nimport comfy.model_management\r\nfrom insightface.app import FaceAnalysis\r\nfrom facexlib.parsing import init_parsing_model\r\nfrom facexlib.utils.face_restoration_helper import FaceRestoreHelper\r\n\r\nimport torch.nn.functional as F\r\n\r\nfrom .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\r\nfrom .encoders_flux import IDFormer, PerceiverAttentionCA\r\n\r\nINSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, \"insightface\")\r\n\r\nMODELS_DIR = os.path.join(folder_paths.models_dir, \"pulid\")\r\nif \"pulid\" not in folder_paths.folder_names_and_paths:\r\n    current_paths = [MODELS_DIR]\r\nelse:\r\n    current_paths, _ = folder_paths.folder_names_and_paths[\"pulid\"]\r\nfolder_paths.folder_names_and_paths[\"pulid\"] = (current_paths, folder_paths.supported_pt_extensions)\r\n\r\nfrom .online_train2 import online_train\r\n\r\nclass PulidFluxModel(nn.Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n        self.double_interval = 2\r\n        self.single_interval = 4\r\n\r\n        # Init encoder\r\n        self.pulid_encoder = IDFormer()\r\n\r\n        # Init attention\r\n        num_ca = 19 // self.double_interval + 38 // self.single_interval\r\n        if 19 % self.double_interval != 0:\r\n            num_ca += 1\r\n        if 38 % self.single_interval != 0:\r\n            num_ca += 1\r\n        self.pulid_ca = nn.ModuleList([\r\n            PerceiverAttentionCA() for _ in range(num_ca)\r\n        ])\r\n\r\n    def from_pretrained(self, path: str):\r\n        state_dict = comfy.utils.load_torch_file(path, safe_load=True)\r\n        state_dict_dict = {}\r\n        for k, v in state_dict.items():\r\n            module = k.split('.')[0]\r\n            state_dict_dict.setdefault(module, {})\r\n            new_k = k[len(module) + 1:]\r\n            state_dict_dict[module][new_k] = v\r\n\r\n        for module in state_dict_dict:\r\n            getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)\r\n\r\n        del state_dict\r\n        del state_dict_dict\r\n\r\n    def get_embeds(self, face_embed, clip_embeds):\r\n        return self.pulid_encoder(face_embed, clip_embeds)\r\n\r\ndef forward_orig(\r\n    self,\r\n    img: Tensor,\r\n    img_ids: Tensor,\r\n    txt: Tensor,\r\n    txt_ids: Tensor,\r\n    timesteps: Tensor,\r\n    y: Tensor,\r\n    guidance: Tensor = None,\r\n    control=None,\r\n    transformer_options={},\r\n    attn_mask: Tensor = None,\r\n    **kwargs # so it won't break if we add more stuff in the future\r\n) -> Tensor:\r\n    device = comfy.model_management.get_torch_device()\r\n    patches_replace = transformer_options.get(\"patches_replace\", {})\r\n\r\n    if img.ndim != 3 or txt.ndim != 3:\r\n        raise ValueError(\"Input img and txt tensors must have 3 dimensions.\")\r\n\r\n    # running on sequences img\r\n    img = self.img_in(img)\r\n    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))\r\n    if self.params.guidance_embed:\r\n        if guidance is None:\r\n            raise ValueError(\"Didn't get guidance strength for guidance distilled model.\")\r\n        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))\r\n\r\n    vec = vec + self.vector_in(y)\r\n    txt = 
self.txt_in(txt)\r\n\r\n    ids = torch.cat((txt_ids, img_ids), dim=1)\r\n    pe = self.pe_embedder(ids)\r\n\r\n    ca_idx = 0\r\n    blocks_replace = patches_replace.get(\"dit\", {})\r\n    for i, block in enumerate(self.double_blocks):\r\n        if (\"double_block\", i) in blocks_replace:\r\n            def block_wrap(args):\r\n                out = {}\r\n                out[\"img\"], out[\"txt\"] = block(img=args[\"img\"],\r\n                                               txt=args[\"txt\"],\r\n                                               vec=args[\"vec\"],\r\n                                               pe=args[\"pe\"],\r\n                                               attn_mask=args.get(\"attn_mask\"))\r\n                return out\r\n\r\n            out = blocks_replace[(\"double_block\", i)]({\"img\": img,\r\n                                                       \"txt\": txt,\r\n                                                       \"vec\": vec,\r\n                                                       \"pe\": pe,\r\n                                                       \"attn_mask\": attn_mask},\r\n                                                      {\"original_block\": block_wrap})\r\n            txt = out[\"txt\"]\r\n            img = out[\"img\"]\r\n        else:\r\n            img, txt = block(img=img,\r\n                             txt=txt,\r\n                             vec=vec,\r\n                             pe=pe,\r\n                             attn_mask=attn_mask)\r\n\r\n        if control is not None: # Controlnet\r\n            control_i = control.get(\"input\")\r\n            if i < len(control_i):\r\n                add = control_i[i]\r\n                if add is not None:\r\n                    img += add\r\n\r\n        # PuLID attention\r\n        if self.pulid_data:\r\n            if i % self.pulid_double_interval == 0:\r\n                # Will calculate influence of all pulid nodes at once\r\n                for _, node_data in self.pulid_data.items():\r\n                    condition_start = node_data['sigma_start'] >= timesteps\r\n                    condition_end = timesteps >= node_data['sigma_end']\r\n                    condition = torch.logical_and(\r\n                        condition_start, condition_end).all()\r\n                    \r\n                    if condition:\r\n                        img = img + node_data['weight'] * self.pulid_ca[ca_idx].to(device)(node_data['embedding'], img)\r\n                ca_idx += 1\r\n\r\n    img = torch.cat((txt, img), 1)\r\n    for i, block in enumerate(self.single_blocks):\r\n        if (\"single_block\", i) in blocks_replace:\r\n            def block_wrap(args):\r\n                out = {}\r\n                out[\"img\"] = block(args[\"img\"],\r\n                                   vec=args[\"vec\"],\r\n                                   pe=args[\"pe\"],\r\n                                   attn_mask=args.get(\"attn_mask\"))\r\n                return out\r\n\r\n            out = blocks_replace[(\"single_block\", i)]({\"img\": img,\r\n                                                       \"vec\": vec,\r\n                                                       \"pe\": pe,\r\n                                                       \"attn_mask\": attn_mask}, \r\n                                                      {\"original_block\": block_wrap})\r\n            img = out[\"img\"]\r\n        else:\r\n            img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)\r\n\r\n        if control is not 
None: # Controlnet\r\n            control_o = control.get(\"output\")\r\n            if i < len(control_o):\r\n                add = control_o[i]\r\n                if add is not None:\r\n                    img[:, txt.shape[1] :, ...] += add\r\n\r\n\r\n        # PuLID attention\r\n        if self.pulid_data:\r\n            real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]\r\n            if i % self.pulid_single_interval == 0:\r\n                # Will calculate influence of all nodes at once\r\n                for _, node_data in self.pulid_data.items():\r\n                    condition_start = node_data['sigma_start'] >= timesteps\r\n                    condition_end = timesteps >= node_data['sigma_end']\r\n\r\n                    # Combine conditions and reduce to a single boolean\r\n                    condition = torch.logical_and(condition_start, condition_end).all()\r\n\r\n                    if condition:\r\n                        real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx].to(device)(node_data['embedding'], real_img)\r\n                ca_idx += 1\r\n            img = torch.cat((txt, real_img), 1)\r\n\r\n    img = img[:, txt.shape[1] :, ...]\r\n\r\n    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)\r\n    return img\r\n\r\ndef tensor_to_image(tensor):\r\n    image = tensor.mul(255).clamp(0, 255).byte().cpu()\r\n    image = image[..., [2, 1, 0]].numpy()\r\n    return image\r\n\r\ndef image_to_tensor(image):\r\n    tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)\r\n    tensor = tensor[..., [2, 1, 0]]\r\n    return tensor\r\n\r\ndef resize_with_pad(img, target_size): # image: 1, h, w, 3\r\n    img = img.permute(0, 3, 1, 2)\r\n    H, W = target_size\r\n    \r\n    h, w = img.shape[2], img.shape[3]\r\n    scale_h = H / h\r\n    scale_w = W / w\r\n    scale = min(scale_h, scale_w)\r\n\r\n    new_h = int(min(h * scale,H))\r\n    new_w = int(min(w * scale,W))\r\n    new_size = (new_h, new_w)\r\n    \r\n    img = F.interpolate(img, size=new_size, mode='bicubic', align_corners=False)\r\n    \r\n    pad_top = (H - new_h) // 2\r\n    pad_bottom = (H - new_h) - pad_top\r\n    pad_left = (W - new_w) // 2\r\n    pad_right = (W - new_w) - pad_left\r\n    img = F.pad(img, pad=(pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)\r\n    \r\n    return img.permute(0, 2, 3, 1)\r\n\r\ndef to_gray(img):\r\n    x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]\r\n    x = x.repeat(1, 3, 1, 1)\r\n    return x\r\n\r\n\"\"\"\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n Nodes\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\"\"\"\r\n\r\nclass PulidFluxModelLoader:\r\n    @classmethod\r\n    def INPUT_TYPES(s):\r\n        return {\"required\": {\"pulid_file\": (folder_paths.get_filename_list(\"pulid\"), )}}\r\n\r\n    RETURN_TYPES = (\"PULIDFLUX\",)\r\n    FUNCTION = \"load_model\"\r\n    CATEGORY = \"pulid\"\r\n\r\n    def load_model(self, pulid_file):\r\n        model_path = folder_paths.get_full_path(\"pulid\", pulid_file)\r\n\r\n        # Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node\r\n        model = PulidFluxModel()\r\n\r\n        logging.info(\"Loading PuLID-Flux model.\")\r\n        model.from_pretrained(path=model_path)\r\n\r\n        return (model,)\r\n\r\nclass PulidFluxInsightFaceLoader:\r\n    @classmethod\r\n    def 
INPUT_TYPES(s):\r\n        return {\r\n            \"required\": {\r\n                \"provider\": ([\"CPU\", \"CUDA\", \"ROCM\"], ),\r\n            },\r\n        }\r\n\r\n    RETURN_TYPES = (\"FACEANALYSIS\",)\r\n    FUNCTION = \"load_insightface\"\r\n    CATEGORY = \"pulid\"\r\n\r\n    def load_insightface(self, provider):\r\n        model = FaceAnalysis(name=\"antelopev2\", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l\r\n        model.prepare(ctx_id=0, det_size=(640, 640))\r\n\r\n        return (model,)\r\n\r\nclass PulidFluxEvaClipLoader:\r\n    @classmethod\r\n    def INPUT_TYPES(s):\r\n        return {\r\n            \"required\": {},\r\n        }\r\n\r\n    RETURN_TYPES = (\"EVA_CLIP\",)\r\n    FUNCTION = \"load_eva_clip\"\r\n    CATEGORY = \"pulid\"\r\n\r\n    def load_eva_clip(self):\r\n        from .eva_clip.factory import create_model_and_transforms\r\n\r\n        model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)\r\n\r\n        model = model.visual\r\n\r\n        eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)\r\n        eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)\r\n        if not isinstance(eva_transform_mean, (list, tuple)):\r\n            model[\"image_mean\"] = (eva_transform_mean,) * 3\r\n        if not isinstance(eva_transform_std, (list, tuple)):\r\n            model[\"image_std\"] = (eva_transform_std,) * 3\r\n\r\n        return (model,)\r\n\r\nclass ApplyPulidFlux:\r\n    @classmethod\r\n    def INPUT_TYPES(s):  \r\n        return {\r\n            \"required\": {\r\n                \"model\": (\"MODEL\", ),\r\n                \"pulid_flux\": (\"PULIDFLUX\", ),\r\n                \"eva_clip\": (\"EVA_CLIP\", ),\r\n                \"face_analysis\": (\"FACEANALYSIS\", ),\r\n                \"image\": (\"IMAGE\", ),\r\n                \"weight\": (\"FLOAT\", {\"default\": 1.0, \"min\": -1.0, \"max\": 5.0, \"step\": 0.05 }),\r\n                \"start_at\": (\"FLOAT\", {\"default\": 0.0, \"min\": 0.0, \"max\": 1.0, \"step\": 0.001 }),\r\n                \"end_at\": (\"FLOAT\", {\"default\": 1.0, \"min\": 0.0, \"max\": 1.0, \"step\": 0.001 }),\r\n                \"fusion\": ([\"mean\",\"concat\",\"max\",\"norm_id\",\"max_token\",\"auto_weight\",\"train_weight\"],),\r\n                \"fusion_weight_max\": (\"FLOAT\", {\"default\": 1.0, \"min\": 0.0, \"max\": 20.0, \"step\": 0.1 }),\r\n                \"fusion_weight_min\": (\"FLOAT\", {\"default\": 0.0, \"min\": 0.0, \"max\": 20.0, \"step\": 0.1 }),\r\n                \"train_step\": (\"INT\", {\"default\": 1000, \"min\": 0, \"max\": 20000, \"step\": 1 }),\r\n                \"use_gray\": (\"BOOLEAN\", {\"default\": True, \"label_on\": \"enabled\", \"label_off\": \"disabled\"}),\r\n            },\r\n            \"optional\": {\r\n                \"attn_mask\": (\"MASK\", ),\r\n                \"prior_image\": (\"IMAGE\",), # for train weight, as the target\r\n            },\r\n            \"hidden\": {\r\n                \"unique_id\": \"UNIQUE_ID\"\r\n            },\r\n        }\r\n\r\n    RETURN_TYPES = (\"MODEL\",)\r\n    FUNCTION = \"apply_pulid_flux\"\r\n    CATEGORY = \"pulid\"\r\n\r\n    def __init__(self):\r\n        self.pulid_data_dict = None\r\n\r\n    def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, prior_image=None,fusion=\"mean\", fusion_weight_max=1.0, fusion_weight_min=0.0, train_step=1000, use_gray=True, 
attn_mask=None, unique_id=None):\r\n        device = comfy.model_management.get_torch_device()\r\n        # Why should I care what args say, when the unet model has a different dtype?!\r\n        # Am I missing something?!\r\n        #dtype = comfy.model_management.unet_dtype()\r\n        dtype = model.model.diffusion_model.dtype\r\n        # For 8bit use bfloat16 (because ufunc_add_CUDA is not implemented)\r\n        if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:\r\n            dtype = torch.bfloat16\r\n\r\n        eva_clip.to(device, dtype=dtype)\r\n        pulid_flux.to(device, dtype=dtype)\r\n\r\n        # TODO: Add masking support!\r\n        if attn_mask is not None:\r\n            if attn_mask.dim() > 3:\r\n                attn_mask = attn_mask.squeeze(-1)\r\n            elif attn_mask.dim() < 3:\r\n                attn_mask = attn_mask.unsqueeze(0)\r\n            attn_mask = attn_mask.to(device, dtype=dtype)\r\n\r\n        if prior_image is not None:\r\n            prior_image = resize_with_pad(prior_image.to(image.device, dtype=image.dtype), target_size=(image.shape[1], image.shape[2]))\r\n            image=torch.cat((prior_image,image),dim=0)\r\n        image = tensor_to_image(image)\r\n\r\n        face_helper = FaceRestoreHelper(\r\n            upscale_factor=1,\r\n            face_size=512,\r\n            crop_ratio=(1, 1),\r\n            det_model='retinaface_resnet50',\r\n            save_ext='png',\r\n            device=device,\r\n        )\r\n\r\n        face_helper.face_parse = None\r\n        face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device)\r\n\r\n        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]\r\n        cond = []\r\n\r\n        # Analyse multiple images at multiple sizes and combine largest area embeddings\r\n        for i in range(image.shape[0]):\r\n            # get insightface embeddings\r\n            iface_embeds = None\r\n            for size in [(size, size) for size in range(640, 256, -64)]:\r\n                face_analysis.det_model.input_size = size\r\n                face_info = face_analysis.get(image[i])\r\n                if face_info:\r\n                    # Only use the maximum face\r\n                    # Removed the reverse=True from original code because we need the largest area not the smallest one!\r\n                    # Sorts the list in ascending order (smallest to largest),\r\n                    # then selects the last element, which is the largest face\r\n                    face_info = sorted(face_info, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]\r\n                    iface_embeds = torch.from_numpy(face_info.embedding).unsqueeze(0).to(device, dtype=dtype)\r\n                    break\r\n            else:\r\n                # No face detected, skip this image\r\n                logging.warning(f'Warning: No face detected in image {str(i)}')\r\n                continue\r\n\r\n            # get eva_clip embeddings\r\n            face_helper.clean_all()\r\n            face_helper.read_image(image[i])\r\n            face_helper.get_face_landmarks_5(only_center_face=True)\r\n            face_helper.align_warp_face()\r\n\r\n            if len(face_helper.cropped_faces) == 0:\r\n                # No face detected, skip this image\r\n                continue\r\n\r\n            # Get aligned face image\r\n            align_face = face_helper.cropped_faces[0]\r\n            # Convert bgr face image to tensor\r\n            align_face = 
image_to_tensor(align_face).unsqueeze(0).permute(0, 3, 1, 2).to(device)\r\n            parsing_out = face_helper.face_parse(functional.normalize(align_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]\r\n            parsing_out = parsing_out.argmax(dim=1, keepdim=True)\r\n            bg = sum(parsing_out == i for i in bg_label).bool()\r\n            white_image = torch.ones_like(align_face)\r\n            # Only keep the face features\r\n            if use_gray:\r\n                _align_face = to_gray(align_face)\r\n            else:\r\n                _align_face = align_face\r\n            face_features_image = torch.where(bg, white_image, _align_face)\r\n\r\n            # Transform img before sending to eva_clip\r\n            # Apparently MPS only supports NEAREST interpolation?\r\n            face_features_image = functional.resize(face_features_image, eva_clip.image_size, transforms.InterpolationMode.BICUBIC if 'cuda' in device.type else transforms.InterpolationMode.NEAREST).to(device, dtype=dtype)\r\n            face_features_image = functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)\r\n\r\n            # eva_clip\r\n            id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)\r\n            id_cond_vit = id_cond_vit.to(device, dtype=dtype)\r\n            for idx in range(len(id_vit_hidden)):\r\n                id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype)\r\n\r\n            id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))\r\n\r\n            # Combine embeddings\r\n            id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)\r\n\r\n            # Pulid_encoder\r\n            cond.append(pulid_flux.get_embeds(id_cond, id_vit_hidden))\r\n\r\n        if not cond:\r\n            # No faces detected, return the original model\r\n            logging.warning(\"PuLID warning: No faces detected in any of the given images, returning unmodified model.\")\r\n            return (model,)\r\n\r\n        # fusion embeddings\r\n        if fusion == \"mean\":\r\n            cond = torch.cat(cond).to(device, dtype=dtype) # N,32,2048\r\n            if cond.shape[0] > 1:\r\n                cond = torch.mean(cond, dim=0, keepdim=True)\r\n        elif fusion == \"concat\":\r\n            cond = torch.cat(cond, dim=1).to(device, dtype=dtype)\r\n        elif fusion == \"max\":\r\n            cond = torch.cat(cond).to(device, dtype=dtype)\r\n            if cond.shape[0] > 1:\r\n                cond = torch.max(cond, dim=0, keepdim=True)[0]\r\n        elif fusion == \"norm_id\":\r\n            cond = torch.cat(cond).to(device, dtype=dtype)\r\n            if cond.shape[0] > 1:\r\n                norm=torch.norm(cond,dim=(1,2))\r\n                norm=norm/torch.sum(norm)\r\n                cond=torch.einsum(\"wij,w->ij\",cond,norm).unsqueeze(0)\r\n        elif fusion == \"max_token\":\r\n            cond = torch.cat(cond).to(device, dtype=dtype)\r\n            if cond.shape[0] > 1:\r\n                norm=torch.norm(cond,dim=2)\r\n                _,idx=torch.max(norm,dim=0)\r\n                cond=torch.stack([cond[j,i] for i,j in enumerate(idx)]).unsqueeze(0)\r\n        elif fusion == \"auto_weight\": # 🤔\r\n            cond = torch.cat(cond).to(device, dtype=dtype)\r\n            if cond.shape[0] > 1:\r\n                norm=torch.norm(cond,dim=2)\r\n                order=torch.argsort(norm,descending=False,dim=0)\r\n                
regular_weight=torch.linspace(fusion_weight_min,fusion_weight_max,norm.shape[0]).to(device, dtype=dtype)\r\n\r\n                _cond=[]\r\n                for i in range(cond.shape[1]):\r\n                    o=order[:,i]\r\n                    _cond.append(torch.einsum('ij,i->j',cond[:,i,:],regular_weight[o]))\r\n                cond=torch.stack(_cond,dim=0).unsqueeze(0)\r\n        elif fusion == \"train_weight\":\r\n            cond = torch.cat(cond).to(device, dtype=dtype)\r\n            if cond.shape[0] > 1:\r\n                if train_step > 0:\r\n                    with torch.inference_mode(False):\r\n                        cond = online_train(cond, device=cond.device, step=train_step)\r\n                else:\r\n                    cond = torch.mean(cond, dim=0, keepdim=True)\r\n\r\n        sigma_start = model.get_model_object(\"model_sampling\").percent_to_sigma(start_at)\r\n        sigma_end = model.get_model_object(\"model_sampling\").percent_to_sigma(end_at)\r\n\r\n        # Patch the Flux model (original diffusion_model)\r\n        # Nah, I don't care for the official ModelPatcher because it's undocumented!\r\n        # I want the end result now, and I don’t mind if I break other custom nodes in the process. 😄\r\n        flux_model = model.model.diffusion_model\r\n        # Let's see if we already patched the underlying flux model, if not apply patch\r\n        if not hasattr(flux_model, \"pulid_ca\"):\r\n            # Add perceiver attention, variables and current node data (weight, embedding, sigma_start, sigma_end)\r\n            # The pulid_data is stored in Dict by unique node index,\r\n            # so we can chain multiple ApplyPulidFlux nodes!\r\n            flux_model.pulid_ca = pulid_flux.pulid_ca\r\n            flux_model.pulid_double_interval = pulid_flux.double_interval\r\n            flux_model.pulid_single_interval = pulid_flux.single_interval\r\n            flux_model.pulid_data = {}\r\n            # Replace model forward_orig with our own\r\n            new_method = forward_orig.__get__(flux_model, flux_model.__class__)\r\n            setattr(flux_model, 'forward_orig', new_method)\r\n\r\n        # Patch is already in place, add data (weight, embedding, sigma_start, sigma_end) under unique node index\r\n        flux_model.pulid_data[unique_id] = {\r\n            'weight': weight,\r\n            'embedding': cond,\r\n            'sigma_start': sigma_start,\r\n            'sigma_end': sigma_end,\r\n        }\r\n\r\n        # Keep a reference for destructor (if node is deleted the data will be deleted as well)\r\n        self.pulid_data_dict = {'data': flux_model.pulid_data, 'unique_id': unique_id}\r\n\r\n        return (model,)\r\n\r\n    def __del__(self):\r\n        # Destroy the data for this node\r\n        if self.pulid_data_dict:\r\n            del self.pulid_data_dict['data'][self.pulid_data_dict['unique_id']]\r\n            del self.pulid_data_dict\r\n\r\n\r\nNODE_CLASS_MAPPINGS = {\r\n    \"PulidFluxModelLoader\": PulidFluxModelLoader,\r\n    \"PulidFluxInsightFaceLoader\": PulidFluxInsightFaceLoader,\r\n    \"PulidFluxEvaClipLoader\": PulidFluxEvaClipLoader,\r\n    \"ApplyPulidFlux\": ApplyPulidFlux,\r\n}\r\n\r\nNODE_DISPLAY_NAME_MAPPINGS = {\r\n    \"PulidFluxModelLoader\": \"Load PuLID Flux Model\",\r\n    \"PulidFluxInsightFaceLoader\": \"Load InsightFace (PuLID Flux)\",\r\n    \"PulidFluxEvaClipLoader\": \"Load Eva Clip (PuLID Flux)\",\r\n    \"ApplyPulidFlux\": \"Apply PuLID Flux\",\r\n}\r\n"
  },
  {
    "path": "requirements.txt",
    "content": "facexlib\r\ninsightface\r\nonnxruntime\r\nonnxruntime-gpu\r\nftfy\r\ntimm\r\n"
  }
]