Repository: sipie800/ComfyUI-PuLID-Flux-Enhanced Branch: main Commit: edcb3af534b8 Files: 36 Total size: 266.0 KB Directory structure: gitextract_gqoxo65u/ ├── LICENSE ├── README.md ├── __init__.py ├── encoders_flux.py ├── eva_clip/ │ ├── __init__.py │ ├── constants.py │ ├── eva_vit_model.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs/ │ │ ├── EVA01-CLIP-B-16.json │ │ ├── EVA01-CLIP-g-14-plus.json │ │ ├── EVA01-CLIP-g-14.json │ │ ├── EVA02-CLIP-B-16.json │ │ ├── EVA02-CLIP-L-14-336.json │ │ ├── EVA02-CLIP-L-14.json │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ └── EVA02-CLIP-bigE-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── rope.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ └── utils.py ├── examples/ │ ├── flux_pulid_multi.json │ ├── pulid_flux_16bit_simple.json │ └── pulid_flux_8bitgguf_simple.json ├── online_train1.py ├── online_train2.py ├── pulidflux.py └── requirements.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================ FILE: README.md ================================================

# ComfyUI-PuLID-Flux-Enhanced

adapted from https://github.com/balazik/ComfyUI-PuLID-Flux

workflow: see the example flux_pulid_multi.json

## oct.7 2025
Formally discontinued. You may simply use i2i models such as Flux Kontext or Qwen Image Edit; they do the same thing as PuLID, or do it better.

## update oct.28 2024
Added an optional prior image input to the node. When the train_weight method is used, the prior image acts as the main id image and leads the other id images to sum into an optimized id embedding. This prior was previously chosen at random; now it can be assigned. Leaving the prior image input empty is fine, exactly as before. Choose the id image you consider best as the prior, or just experiment and see what happens.

![oct28](https://github.com/user-attachments/assets/6a481cd9-2836-4f6f-9ad5-7458356c332a)

## new features

### common fusion methods for multi-image input
mean (official), concat, max, etc.

### some further experimental fusion methods
- using the norm of the conditions to weight them
- using the max-norm token among images
- a novel, very fast embedding self-training method (explained here: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/28)

### switch between using the gray image (official) and rgb
In some cases, using the gray image causes detail loss.

![2024-10-12_204047](https://github.com/user-attachments/assets/0ae96170-2eff-44e9-a53a-6a7447dbc0f1)

## tricks to make your generation better

### fusion methods leverage many id images to enhance fidelity
1. Besides mean fusion, you can try max or max_token, which can boost a major feature of a face (large eyes, a distinctive nose, and so on). It can push past fidelity into distortion, though.
2. With the train_weight method, fewer than 2000 steps are enough to get a deeper fusion than the non-training methods. Be aware that too many training steps will collapse the result onto the prior image.

### additional notes
1. Flux is a high-capacity base model; it can pick up the input image in almost superhuman detail. For example, resize your high-quality input image with the Lanczos method rather than nearest-area or bilinear and you will get finer texture. Keep in mind that taking care of your input image matters when the base model is strong.
2. The best pulid weight is around 0.8-0.95 for flux pulid 0.9.0; 1.0 is not good. For 0.9.1 it is higher, around 0.9-1.0. Nonetheless, 0.9.1 is not always better than 0.9.0.
3. The base model is flux-dev or a finetune of it, and precision matters. fp16 should always be sound; fp8 is OK. I don't recommend gguf or nf4 variants.
4. Some finetuned flux-dev models may carry a strong bias; for example, they may sway faces toward a certain ethnicity.
5. Euler/simple always works. Euler/beta gives higher quality, especially if your input image is somewhat low quality.
6. If you want to use third-party flux-dev weights, it is better to use a merged checkpoint or one with a LoRA applied, rather than a fully finetuned one. Full finetuning can hurt the connection between PuLID and the original flux-dev base model. You can test this yourself, though.

## basic notes for common users
This is an experimental node. It can give enhanced results, but no promises.

Basic instructions for users with little Python or AI development background: please follow the ComfyUI instructions or https://github.com/balazik/ComfyUI-PuLID-Flux to get it running. If you just want SDXL PuLID, you can use https://github.com/cubiq/PuLID_ComfyUI; some of the installation instructions there may also help.
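For a concrete picture of what the fusion methods above mean, here is a minimal, self-contained sketch. It is not the node's actual implementation (see pulidflux.py for that); the function name, the `(num_images, num_tokens, dim)` embedding layout, and the `norm_weighted` label are illustrative assumptions.

```python
import torch

def fuse_id_embeddings(id_embeds: torch.Tensor, method: str = "mean") -> torch.Tensor:
    """Fuse per-image id embeddings of shape (num_images, num_tokens, dim)
    into a single (1, n_tokens, dim) conditioning tensor (hypothetical sketch)."""
    if method == "mean":           # official behaviour: average over the image axis
        return id_embeds.mean(dim=0, keepdim=True)
    if method == "concat":         # keep every image's tokens side by side
        return id_embeds.reshape(1, -1, id_embeds.shape[-1])
    if method == "max":            # element-wise max across images
        return id_embeds.max(dim=0, keepdim=True).values
    if method == "max_token":      # per token position, keep the token with the largest norm
        norms = id_embeds.norm(dim=-1)                      # (num_images, num_tokens)
        best = norms.argmax(dim=0)                          # (num_tokens,)
        picked = id_embeds[best, torch.arange(id_embeds.shape[1])]
        return picked.unsqueeze(0)
    if method == "norm_weighted":  # weight each image by the norm of its embedding
        w = id_embeds.flatten(1).norm(dim=1)                # (num_images,)
        w = w / w.sum()
        return (id_embeds * w.view(-1, 1, 1)).sum(dim=0, keepdim=True)
    raise ValueError(f"unknown fusion method: {method}")
```

For example, `fuse_id_embeddings(torch.randn(4, 32, 2048), "max_token")` picks, token position by token position, the strongest token among four reference embeddings; this illustrates why max-style fusions can exaggerate a dominant facial feature.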
================================================ FILE: __init__.py ================================================ from .pulidflux import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] ================================================ FILE: encoders_flux.py ================================================ import math import torch import torch.nn as nn # FFN def FeedForward(dim, mult=4): inner_dim = int(dim * mult) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias=False), nn.GELU(), nn.Linear(inner_dim, dim, bias=False), ) def reshape_tensor(x, heads): bs, length, width = x.shape # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) x = x.reshape(bs, heads, length, -1) return x class PerceiverAttentionCA(nn.Module): def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048): super().__init__() self.scale = dim_head ** -0.5 self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) self.norm2 = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) def forward(self, x, latents): """ Args: x (torch.Tensor): image features shape (b, n1, D) latent (torch.Tensor): latent features shape (b, n2, D) """ x = self.norm1(x) latents = self.norm2(latents) b, seq_len, _ = latents.shape q = self.to_q(latents) k, v = self.to_kv(x).chunk(2, dim=-1) q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) return self.to_out(out) class PerceiverAttention(nn.Module): def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None): super().__init__() self.scale = dim_head ** -0.5 self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) self.norm2 = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) def forward(self, x, latents): """ Args: x (torch.Tensor): image features shape (b, n1, D) latent (torch.Tensor): latent features shape (b, n2, D) """ x = self.norm1(x) latents = self.norm2(latents) b, seq_len, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) return self.to_out(out) class IDFormer(nn.Module): """ - perceiver resampler like arch (compared 
with previous MLP-like arch) - we concat id embedding (generated by arcface) and query tokens as latents - latents will attend each other and interact with vit features through cross-attention - vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two IDFormer layers """ def __init__( self, dim=1024, depth=10, dim_head=64, heads=16, num_id_token=5, num_queries=32, output_dim=2048, ff_mult=4, ): super().__init__() self.num_id_token = num_id_token self.dim = dim self.num_queries = num_queries assert depth % 5 == 0 self.depth = depth // 5 scale = dim ** -0.5 self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale) self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim)) self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( nn.ModuleList( [ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), FeedForward(dim=dim, mult=ff_mult), ] ) ) for i in range(5): setattr( self, f'mapping_{i}', nn.Sequential( nn.Linear(1024, 1024), nn.LayerNorm(1024), nn.LeakyReLU(), nn.Linear(1024, 1024), nn.LayerNorm(1024), nn.LeakyReLU(), nn.Linear(1024, dim), ), ) self.id_embedding_mapping = nn.Sequential( nn.Linear(1280, 1024), nn.LayerNorm(1024), nn.LeakyReLU(), nn.Linear(1024, 1024), nn.LayerNorm(1024), nn.LeakyReLU(), nn.Linear(1024, dim * num_id_token), ) def forward(self, x, y): latents = self.latents.repeat(x.size(0), 1, 1) x = self.id_embedding_mapping(x) x = x.reshape(-1, self.num_id_token, self.dim) latents = torch.cat((latents, x), dim=1) for i in range(5): vit_feature = getattr(self, f'mapping_{i}')(y[i]) ctx_feature = torch.cat((x, vit_feature), dim=1) for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]: latents = attn(ctx_feature, latents) + latents latents = ff(latents) + latents latents = latents[:, :self.num_queries] latents = latents @ self.proj_out return latents ================================================ FILE: eva_clip/__init__.py ================================================ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .loss import ClipLoss from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained from .tokenizer import SimpleTokenizer, tokenize from .transform import image_transform ================================================ FILE: eva_clip/constants.py ================================================ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) ================================================ FILE: eva_clip/eva_vit_model.py ================================================ # -------------------------------------------------------- # Adapted from https://github.com/microsoft/unilm/tree/master/beit # -------------------------------------------------------- import math import os from functools import partial import torch import torch.nn as nn import torch.nn.functional as F try: from timm.models.layers import drop_path, 
to_2tuple, trunc_normal_ except: from timm.layers import drop_path, to_2tuple, trunc_normal_ from .transformer import PatchDropout from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast if os.getenv('ENV_TYPE') == 'deepspeed': try: from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint except: from torch.utils.checkpoint import checkpoint else: from torch.utils.checkpoint import checkpoint try: import xformers import xformers.ops as xops XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0., subln=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.ffn_ln(x) x = self.fc2(x) x = self.drop(x) return x class SwiGLU(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., norm_layer=nn.LayerNorm, subln=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.w1 = nn.Linear(in_features, hidden_features) self.w2 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() self.w3 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x1 = self.w1(x) x2 = self.w2(x) hidden = self.act(x1) * x2 x = self.ffn_ln(hidden) x = self.w3(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = qk_scale or head_dim ** -0.5 self.subln = subln if self.subln: self.q_proj = nn.Linear(dim, all_head_dim, bias=False) self.k_proj = nn.Linear(dim, all_head_dim, bias=False) self.v_proj = nn.Linear(dim, all_head_dim, bias=False) else: self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) 
coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() # self.proj = nn.Linear(all_head_dim, all_head_dim) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.xattn = xattn self.xattn_drop = attn_drop self.rope = rope def forward(self, x, rel_pos_bias=None, attn_mask=None): B, N, C = x.shape if self.subln: q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) k = F.linear(input=x, weight=self.k_proj.weight, bias=None) v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) else: qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C q, k, v = qkv[0], qkv[1], qkv[2] if self.rope: # slightly fast impl q_t = q[:, :, 1:, :] ro_q_t = self.rope(q_t) q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) k_t = k[:, :, 1:, :] ro_k_t = self.rope(k_t) k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) if self.xattn: q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) x = xops.memory_efficient_attention( q, k, v, p=self.xattn_drop, scale=self.scale, ) x = x.reshape(B, N, -1) x = self.inner_attn_ln(x) x = self.proj(x) x = self.proj_drop(x) else: q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) if rel_pos_bias is not None: attn = attn + rel_pos_bias.type_as(attn) if attn_mask is not None: attn_mask = attn_mask.bool() attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = 
self.inner_attn_ln(x) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False, subln=False, naiveswiglu=False): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) if naiveswiglu: self.mlp = SwiGLU( in_features=dim, hidden_features=mlp_hidden_dim, subln=subln, norm_layer=norm_layer, ) else: self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop ) if init_values is not None and init_values > 0: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None self.postnorm = postnorm def forward(self, x, rel_pos_bias=None, attn_mask=None): if self.gamma_1 is None: if self.postnorm: x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) x = x + self.drop_path(self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: if self.postnorm: x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x).flatten(2).transpose(1, 2) return x class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww class EVAVisionTransformer(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0., use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False, use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False, pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False): super().__init__() if not XFORMERS_IS_AVAILBLE: xattn = False self.image_size = img_size self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None if rope: half_head_dim = embed_dim // num_heads // 2 hw_seq_len = img_size // patch_size self.rope = VisionRotaryEmbeddingFast( dim=half_head_dim, pt_seq_len=pt_hw_seq_len, ft_seq_len=hw_seq_len if intp_freq else None, # patch_dropout=patch_dropout ) else: 
self.rope = None self.naiveswiglu = naiveswiglu dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu) for i in range(depth)]) self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) self.apply(self._init_weights) self.fix_init_weight() if isinstance(self.head, nn.Linear): trunc_normal_(self.head.weight, std=.02) self.head.weight.data.mul_(init_scale) self.head.bias.data.mul_(init_scale) # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() self.grad_checkpointing = grad_checkpointing def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) if self.naiveswiglu: rescale(layer.mlp.w3.weight.data, layer_id + 1) else: rescale(layer.mlp.fc2.weight.data, layer_id + 1) def get_cast_dtype(self) -> torch.dtype: return self.blocks[0].mlp.fc2.weight.dtype def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def get_num_layers(self): return len(self.blocks) def lock(self, unlocked_groups=0, freeze_bn_stats=False): assert unlocked_groups == 0, 'partial locking not currently supported for this model' for param in self.parameters(): param.requires_grad = False @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token'} def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False): x = self.patch_embed(x) batch_size, seq_len, _ = x.size() if shuffle: idx = torch.randperm(x.shape[1]) + 1 zero = torch.LongTensor([0, ]) idx = torch.cat([zero, idx]) pos_embed = self.pos_embed[:, idx] cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if shuffle: x = x + pos_embed elif self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in if os.getenv('RoPE') == '1': if self.training and not isinstance(self.patch_dropout, nn.Identity): x, patch_indices_keep = self.patch_dropout(x) self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) else: self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) x = self.patch_dropout(x) else: x = self.patch_dropout(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None hidden_states = [] for idx, blk in enumerate(self.blocks): if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden: hidden_states.append(x) if self.grad_checkpointing: x = checkpoint(blk, x, (rel_pos_bias,)) else: x = blk(x, rel_pos_bias=rel_pos_bias) if not return_all_features: x = self.norm(x) if self.fc_norm is not None: return self.fc_norm(x.mean(1)), hidden_states else: return x[:, 0], hidden_states return x def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False): if return_all_features: return self.forward_features(x, return_all_features, return_hidden, shuffle) x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle) x = self.head(x) if return_hidden: return x, hidden_states return x ================================================ FILE: eva_clip/factory.py ================================================ import json import logging import os import pathlib import re from copy import deepcopy from pathlib import Path from typing import Optional, Tuple, Union, Dict, Any import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ get_cast_dtype from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model from .transform import image_transform from .tokenizer import HFTokenizer, tokenize from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS config_ext = ('.json',) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: config_files.extend(config_path.glob(f'*{ext}')) for cf in config_files: with open(cf, "r", encoding="utf8") as f: model_cfg = json.load(f) if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) _rescan_model_configs() # initial populate of model config registry def list_models(): """ enumerate available model architectures based on config files """ return list(_MODEL_CONFIGS.keys()) def add_model_config(path): """ add model config path or file and update registry """ if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) _rescan_model_configs() def get_model_config(model_name): if model_name in _MODEL_CONFIGS: return deepcopy(_MODEL_CONFIGS[model_name]) else: return None def get_tokenizer(model_name): config = 
get_model_config(model_name) tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize return tokenizer # loading openai CLIP weights when is_openai=True for training def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]): if is_openai: model = torch.jit.load(checkpoint_path, map_location="cpu").eval() state_dict = model.state_dict() for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) else: checkpoint = torch.load(checkpoint_path, map_location=map_location) for mk in model_key.split('|'): if isinstance(checkpoint, dict) and mk in checkpoint: state_dict = checkpoint[mk] break else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'): state_dict = {k[7:]: v for k, v in state_dict.items()} for k in skip_list: if k in list(state_dict.keys()): logging.info(f"Removing key {k} from pretrained checkpoint") del state_dict[k] if os.getenv('RoPE') == '1': for k in list(state_dict.keys()): if 'freqs_cos' in k or 'freqs_sin' in k: del state_dict[k] return state_dict def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) # detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'): state_dict['logit_scale'] = state_dict['text.logit_scale'] del state_dict['text.logit_scale'] # resize_clip_pos_embed for CLIP and open CLIP if 'visual.positional_embedding' in state_dict: resize_clip_pos_embed(state_dict, model) # specified to eva_vit_model elif 'visual.pos_embed' in state_dict: resize_evaclip_pos_embed(state_dict, model) # resize_clip_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=strict) logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") return incompatible_keys def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) for k in list(state_dict.keys()): if not k.startswith('visual.'): del state_dict[k] for k in list(state_dict.keys()): if k.startswith('visual.'): new_k = k[7:] state_dict[new_k] = state_dict[k] del state_dict[k] return state_dict def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) for k in list(state_dict.keys()): if k.startswith('visual.'): del state_dict[k] return state_dict def get_pretrained_tag(pretrained_model): pretrained_model = pretrained_model.lower() if "laion" in pretrained_model or "open_clip" in pretrained_model: return "open_clip" elif "openai" in pretrained_model: return "clip" elif "eva" in pretrained_model and "clip" in pretrained_model: return "eva_clip" else: return "other" def load_pretrained_checkpoint( model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]): visual_tag = 
get_pretrained_tag(visual_model) text_tag = get_pretrained_tag(text_model) logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") visual_incompatible_keys, text_incompatible_keys = None, None if visual_checkpoint_path: if visual_tag == "eva_clip" or visual_tag == "open_clip": visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) elif visual_tag == "clip": visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) else: visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) # resize_clip_pos_embed for CLIP and open CLIP if 'positional_embedding' in visual_state_dict: resize_visual_pos_embed(visual_state_dict, model) # specified to EVA model elif 'pos_embed' in visual_state_dict: resize_eva_pos_embed(visual_state_dict, model) visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") if text_checkpoint_path: if text_tag == "eva_clip" or text_tag == "open_clip": text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) elif text_tag == "clip": text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) else: text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") return visual_incompatible_keys, text_incompatible_keys def create_model( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_clip: bool = False, force_patch_dropout: Optional[float] = None, pretrained_image: str = '', pretrained_text: str = '', pretrained_hf: bool = True, pretrained_visual_model: str = None, pretrained_text_model: str = None, cache_dir: Optional[str] = None, skip_list: list = [], ): model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names if isinstance(device, str): device = torch.device(device) if pretrained and pretrained.lower() == 'openai': logging.info(f'Loading pretrained {model_name} from OpenAI.') model = load_openai_model( model_name, precision=precision, device=device, jit=jit, cache_dir=cache_dir, ) else: model_cfg = get_model_config(model_name) if model_cfg is not None: logging.info(f'Loaded {model_name} model config.') else: logging.error(f'Model config for {model_name} not found; available models {list_models()}.') raise RuntimeError(f'Model config for {model_name} not found.') if 'rope' in model_cfg.get('vision_cfg', {}): if model_cfg['vision_cfg']['rope']: os.environ['RoPE'] = "1" else: os.environ['RoPE'] = "0" if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models model_cfg["quick_gelu"] = True if force_patch_dropout is not None: # override the default patch dropout value model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout cast_dtype = get_cast_dtype(precision) custom_clip = 
model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg']) if custom_clip: if 'hf_model_name' in model_cfg.get('text_cfg', {}): model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_cfg = {} if pretrained: checkpoint_path = '' pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) elif os.path.exists(pretrained): checkpoint_path = pretrained if checkpoint_path: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False ) else: error_str = ( f'Pretrained weights ({pretrained}) not found for model {model_name}.' f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') logging.warning(error_str) raise RuntimeError(error_str) else: visual_checkpoint_path = '' text_checkpoint_path = '' if pretrained_image: pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) if 'timm_model_name' in model_cfg.get('vision_cfg', {}): # pretrained weight loading for timm models set via vision_cfg model_cfg['vision_cfg']['timm_model_pretrained'] = True elif pretrained_image_cfg: visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) elif os.path.exists(pretrained_image): visual_checkpoint_path = pretrained_image else: logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') if pretrained_text: pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) if pretrained_image_cfg: text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) elif os.path.exists(pretrained_text): text_checkpoint_path = pretrained_text else: logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') if visual_checkpoint_path: logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).') if text_checkpoint_path: logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).') if visual_checkpoint_path or text_checkpoint_path: load_pretrained_checkpoint( model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list ) if "fp16" in precision or "bf16" in precision: logging.info(f'convert precision to {precision}') model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16) model.to(device=device) # set image / mean metadata from pretrained_cfg if available, or use default model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD if jit: model = torch.jit.script(model) return model def 
create_model_and_transforms( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_clip: bool = False, force_patch_dropout: Optional[float] = None, pretrained_image: str = '', pretrained_text: str = '', pretrained_hf: bool = True, pretrained_visual_model: str = None, pretrained_text_model: str = None, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] = None, skip_list: list = [], ): model = create_model( model_name, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_clip=force_custom_clip, force_patch_dropout=force_patch_dropout, pretrained_image=pretrained_image, pretrained_text=pretrained_text, pretrained_hf=pretrained_hf, pretrained_visual_model=pretrained_visual_model, pretrained_text_model=pretrained_text_model, cache_dir=cache_dir, skip_list=skip_list, ) image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) preprocess_train = image_transform( model.visual.image_size, is_train=True, mean=image_mean, std=image_std ) preprocess_val = image_transform( model.visual.image_size, is_train=False, mean=image_mean, std=image_std ) return model, preprocess_train, preprocess_val def create_transforms( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_clip: bool = False, force_patch_dropout: Optional[float] = None, pretrained_image: str = '', pretrained_text: str = '', pretrained_hf: bool = True, pretrained_visual_model: str = None, pretrained_text_model: str = None, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] = None, skip_list: list = [], ): model = create_model( model_name, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_clip=force_custom_clip, force_patch_dropout=force_patch_dropout, pretrained_image=pretrained_image, pretrained_text=pretrained_text, pretrained_hf=pretrained_hf, pretrained_visual_model=pretrained_visual_model, pretrained_text_model=pretrained_text_model, cache_dir=cache_dir, skip_list=skip_list, ) image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) preprocess_train = image_transform( model.visual.image_size, is_train=True, mean=image_mean, std=image_std ) preprocess_val = image_transform( model.visual.image_size, is_train=False, mean=image_mean, std=image_std ) del model return preprocess_train, preprocess_val def create_model_from_pretrained( model_name: str, pretrained: str, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_clip: bool = False, force_patch_dropout: Optional[float] = None, return_transform: bool = True, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] = None, is_frozen: bool = False, ): if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): raise RuntimeError( f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.' 
f' Use open_clip.list_pretrained() to find one.') model = create_model( model_name, pretrained, precision=precision, device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_clip=force_custom_clip, force_patch_dropout=force_patch_dropout, cache_dir=cache_dir, ) if is_frozen: for param in model.parameters(): param.requires_grad = False if not return_transform: return model image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) preprocess = image_transform( model.visual.image_size, is_train=False, mean=image_mean, std=image_std ) return model, preprocess ================================================ FILE: eva_clip/hf_configs.py ================================================ # HF architecture dict: arch_dict = { # https://huggingface.co/docs/transformers/model_doc/roberta#roberta "roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig "xlm-roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 "mt5": { "config_names": { # unlimited seqlen # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 "context_length": "", "vocab_size": "vocab_size", "width": "d_model", "heads": "num_heads", "layers": "num_layers", "layer_attr": "block", "token_embeddings_attr": "embed_tokens" }, "pooler": "mean_pooler", }, "bert": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", } } ================================================ FILE: eva_clip/hf_model.py ================================================ """ huggingface model adapter Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. """ import re import torch import torch.nn as nn from torch.nn import functional as F from torch import TensorType try: import transformers from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ BaseModelOutputWithPoolingAndCrossAttentions except ImportError as e: transformers = None class BaseModelOutput: pass class PretrainedConfig: pass from .hf_configs import arch_dict # utils def _camel2snake(s): return re.sub(r'(? 
TensorType: # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) # attn_mask = (x != self.config.pad_token_id).long() # out = self.transformer( # input_ids=x, # attention_mask=attn_mask, # encoder_hidden_states = image_embeds, # encoder_attention_mask = image_atts, # ) # pooled_out = self.pooler(out, attn_mask) # return self.itm_proj(pooled_out) def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): if masked_indices is None: masked_indices = torch.bernoulli(probability_matrix).bool() masked_indices[input_ids == self.tokenizer.pad_token_id] = False masked_indices[input_ids == self.tokenizer.cls_token_id] = False if targets is not None: targets[~masked_indices] = -100 # We only compute loss on masked tokens # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices input_ids[indices_replaced] = self.tokenizer.mask_token_id # 10% of the time, we replace masked input tokens with random word indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) input_ids[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input tokens unchanged if targets is not None: return input_ids, targets else: return input_ids def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): labels = input_ids.clone() attn_mask = (input_ids != self.config.pad_token_id).long() image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) probability_matrix = torch.full(labels.shape, mlm_probability) input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix = probability_matrix) mlm_output = self.transformer(input_ids, attention_mask = attn_mask, encoder_hidden_states = image_embeds, encoder_attention_mask = image_atts, return_dict = True, labels = labels, ) return mlm_output.loss # mlm_output = self.transformer(input_ids, # attention_mask = attn_mask, # encoder_hidden_states = image_embeds, # encoder_attention_mask = image_atts, # return_dict = True, # ).last_hidden_state # logits = self.mlm_proj(mlm_output) # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) # labels = labels[:, 1:].contiguous().view(-1) # mlm_loss = F.cross_entropy( # logits, # labels, # # label_smoothing=0.1, # ) # return mlm_loss def forward(self, x:TensorType) -> TensorType: attn_mask = (x != self.config.pad_token_id).long() out = self.transformer(input_ids=x, attention_mask=attn_mask) pooled_out = self.pooler(out, attn_mask) return self.proj(pooled_out) def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): if not unlocked_layers: # full freezing for n, p in self.transformer.named_parameters(): p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False return encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") embeddings = getattr( self.transformer, 
arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) modules = [embeddings, *layer_list][:-unlocked_layers] # freeze layers for module in modules: for n, p in module.named_parameters(): p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.gradient_checkpointing_enable() def get_num_layers(self): encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) return len(layer_list) def init_parameters(self): pass ================================================ FILE: eva_clip/loss.py ================================================ import math import torch import torch.nn as nn from torch.nn import functional as F try: import torch.distributed.nn from torch import distributed as dist has_distributed = True except ImportError: has_distributed = False try: import horovod.torch as hvd except ImportError: hvd = None from timm.loss import LabelSmoothingCrossEntropy def gather_features( image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False ): assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' if use_horovod: assert hvd is not None, 'Please install horovod' if gather_with_grad: all_image_features = hvd.allgather(image_features) all_text_features = hvd.allgather(text_features) else: with torch.no_grad(): all_image_features = hvd.allgather(image_features) all_text_features = hvd.allgather(text_features) if not local_loss: # ensure grads for local rank when all_* features don't have a gradient gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) gathered_image_features[rank] = image_features gathered_text_features[rank] = text_features all_image_features = torch.cat(gathered_image_features, dim=0) all_text_features = torch.cat(gathered_text_features, dim=0) else: # We gather tensors from all gpus if gather_with_grad: all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) else: gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] dist.all_gather(gathered_image_features, image_features) dist.all_gather(gathered_text_features, text_features) if not local_loss: # ensure grads for local rank when all_* features don't have a gradient gathered_image_features[rank] = image_features gathered_text_features[rank] = text_features all_image_features = torch.cat(gathered_image_features, dim=0) all_text_features = torch.cat(gathered_text_features, dim=0) return all_image_features, all_text_features class ClipLoss(nn.Module): def __init__( self, local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1, use_horovod=False, smoothing=0., ): super().__init__() self.local_loss = local_loss self.gather_with_grad = gather_with_grad self.cache_labels = cache_labels self.rank = rank 
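# Illustrative note on the objective computed by forward() below (a sketch, not quoted from the
# upstream file): for a batch of N matched (image, text) pairs the loss builds an N x N matrix
#   logits_per_image = logit_scale * image_features @ text_features.T
# and scores it against labels = torch.arange(N), i.e. the i-th image must rank its own caption
# first; the symmetric cross-entropy over rows and columns is then averaged.
# Minimal single-process usage (hypothetical tensors, world_size == 1):
#   img = F.normalize(torch.randn(4, 512), dim=-1)
#   txt = F.normalize(torch.randn(4, 512), dim=-1)
#   loss, acc = ClipLoss()(img, txt, logit_scale=100.0)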
self.world_size = world_size self.use_horovod = use_horovod self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None # cache state self.prev_num_logits = 0 self.labels = {} def forward(self, image_features, text_features, logit_scale=1.): device = image_features.device if self.world_size > 1: all_image_features, all_text_features = gather_features( image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T logits_per_text = logit_scale * text_features @ all_image_features.T else: logits_per_image = logit_scale * all_image_features @ all_text_features.T logits_per_text = logits_per_image.T else: logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T # calculated ground-truth and cache if enabled num_logits = logits_per_image.shape[0] if self.prev_num_logits != num_logits or device not in self.labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) if self.world_size > 1 and self.local_loss: labels = labels + num_logits * self.rank if self.cache_labels: self.labels[device] = labels self.prev_num_logits = num_logits else: labels = self.labels[device] if self.label_smoothing_cross_entropy: total_loss = ( self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels) ) / 2 else: total_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 acc = None i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) acc = {"i2t": i2t_acc, "t2i": t2i_acc} return total_loss, acc ================================================ FILE: eva_clip/model.py ================================================ """ CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import os from dataclasses import dataclass from typing import Optional, Tuple, Union from functools import partial import numpy as np import torch import torch.nn.functional as F from torch import nn try: from .hf_model import HFTextEncoder except: HFTextEncoder = None from .modified_resnet import ModifiedResNet from .timm_model import TimmModel from .eva_vit_model import EVAVisionTransformer from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer try: from apex.normalization import FusedLayerNorm except: FusedLayerNorm = LayerNorm print("Nvidia APEX normalization not installed, using PyTorch LayerNorm") try: import xformers.ops as xops except ImportError: xops = None #print("Please 'pip install xformers'") @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) drop_path_rate: Optional[float] = None # drop path rate timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size qkv_bias: bool = True fusedLN: bool = False xattn: bool = False postnorm: bool = False rope: bool = False pt_hw_seq_len: int = 16 # 224/14 intp_freq: bool = False naiveswiglu: bool = False subln: bool = False @dataclass class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 width: int = 512 heads: int = 8 layers: int = 12 ls_init_value: Optional[float] = None # layer scale initial value hf_model_name: str = None hf_tokenizer_name: str = None hf_model_pretrained: bool = True proj: str = 'mlp' pooler_type: str = 'mean_pooler' masked_language_modeling: bool = False fusedLN: bool = False xattn: bool = False attn_mask: bool = True def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
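# Illustrative comparison of the two activations referenced above (QuickGELU itself is imported
# from .transformer; the formula below is the standard sigmoid approximation and is assumed
# here rather than quoted from that file):
#   QuickGELU(x) ~= x * torch.sigmoid(1.702 * x)   # what the OpenAI checkpoints were trained with
#   nn.GELU()(x)  = x * Phi(x)                     # exact Gaussian CDF; preferred on PyTorch >= 1.10
# so quick_gelu=True mainly matters when loading weights that were pretrained with QuickGELU.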
act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.eva_model_name: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNorm visual = EVAVisionTransformer( img_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, num_classes=embed_dim, use_mean_pooling=vision_cfg.global_average_pool, #False init_values=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, embed_dim=vision_cfg.width, depth=vision_cfg.layers, num_heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, qkv_bias=vision_cfg.qkv_bias, drop_path_rate=vision_cfg.drop_path_rate, norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6), xattn=vision_cfg.xattn, rope=vision_cfg.rope, postnorm=vision_cfg.postnorm, pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14 intp_freq= vision_cfg.intp_freq, naiveswiglu= vision_cfg.naiveswiglu, subln= vision_cfg.subln ) elif vision_cfg.timm_model_name: visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size ) act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width ) else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, width=vision_cfg.width, layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, global_average_pool=vision_cfg.global_average_pool, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return visual def _build_text_tower( embed_dim: int, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) if text_cfg.hf_model_name: text = HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling ) else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNorm text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer, xattn=text_cfg.xattn, attn_mask=text_cfg.attn_mask, ) return text class CLIP(nn.Module): def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): super().__init__() self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = 
text.ln_final self.text_projection = text.text_projection self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): return {'logit_scale'} def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return F.normalize(x, dim=-1) if normalize else x def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) return image_features, text_features, self.logit_scale.exp() class CustomCLIP(nn.Module): def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, itm_task: bool = False, ): super().__init__() self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): self.text.lock(unlocked_layers, freeze_layer_norm) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) @torch.jit.ignore def no_weight_decay(self): return {'logit_scale'} def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) return image_features, text_features, self.logit_scale.exp() def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) if l.bias is not None: l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", 
"q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr, None) if tensor is not None: tensor.data = tensor.data.to(dtype) if isinstance(l, nn.Parameter): l.data = l.data.to(dtype) for name in ["text_projection", "proj"]: if hasattr(l, name) and isinstance(l, nn.Parameter): attr = getattr(l, name, None) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', 'logit_scale' )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) text_cfg = CLIPTextCfg( context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers ) model = CLIP( embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model ================================================ FILE: 
eva_clip/model_configs/EVA01-CLIP-B-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 16, "eva_model_name": "eva-clip-b-16", "ls_init_value": 0.1, "drop_path_rate": 0.0 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 40, "width": 1408, "head_width": 88, "mlp_ratio": 4.3637, "patch_size": 14, "eva_model_name": "eva-clip-g-14-x", "drop_path_rate": 0, "xattn": true, "fusedLN": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA01-CLIP-g-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 40, "width": 1408, "head_width": 88, "mlp_ratio": 4.3637, "patch_size": 14, "eva_model_name": "eva-clip-g-14-x", "drop_path_rate": 0.4, "xattn": true, "fusedLN": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA02-CLIP-B-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "head_width": 64, "patch_size": 16, "mlp_ratio": 2.6667, "eva_model_name": "eva-clip-b-16-X", "drop_path_rate": 0.0, "xattn": true, "fusedLN": true, "rope": true, "pt_hw_seq_len": 16, "intp_freq": true, "naiveswiglu": true, "subln": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "xattn": true, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA02-CLIP-L-14-336.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 336, "layers": 24, "width": 1024, "drop_path_rate": 0, "head_width": 64, "mlp_ratio": 2.6667, "patch_size": 14, "eva_model_name": "eva-clip-l-14-336", "xattn": true, "fusedLN": true, "rope": true, "pt_hw_seq_len": 16, "intp_freq": true, "naiveswiglu": true, "subln": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA02-CLIP-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "drop_path_rate": 0, "head_width": 64, "mlp_ratio": 2.6667, "patch_size": 14, "eva_model_name": "eva-clip-l-14", "xattn": true, "fusedLN": true, "rope": true, "pt_hw_seq_len": 16, "intp_freq": true, "naiveswiglu": true, "subln": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 64, "width": 1792, "head_width": 112, "mlp_ratio": 8.571428571428571, "patch_size": 14, 
"eva_model_name": "eva-clip-4b-14-x", "drop_path_rate": 0, "xattn": true, "postnorm": true, "fusedLN": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 32, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/model_configs/EVA02-CLIP-bigE-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 64, "width": 1792, "head_width": 112, "mlp_ratio": 8.571428571428571, "patch_size": 14, "eva_model_name": "eva-clip-4b-14-x", "drop_path_rate": 0, "xattn": true, "postnorm": true, "fusedLN": true }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24, "xattn": false, "fusedLN": true } } ================================================ FILE: eva_clip/modified_resnet.py ================================================ from collections import OrderedDict import torch from torch import nn from torch.nn import functional as F from .utils import freeze_batch_norm_2d class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1): super().__init__() # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.act2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.act3 = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential(OrderedDict([ ("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion)) ])) def forward(self, x: torch.Tensor): identity = x out = self.act1(self.bn1(self.conv1(x))) out = self.act2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.act3(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, 
bias_v=None, add_zero_attn=False, dropout_p=0., out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__(self, layers, output_dim, heads, image_size=224, width=64): super().__init__() self.output_dim = output_dim self.image_size = image_size # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.act2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.act3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) self.init_parameters() def _make_layer(self, planes, blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def init_parameters(self): if self.attnpool is not None: std = self.attnpool.c_proj.in_features ** -0.5 nn.init.normal_(self.attnpool.q_proj.weight, std=std) nn.init.normal_(self.attnpool.k_proj.weight, std=std) nn.init.normal_(self.attnpool.v_proj.weight, std=std) nn.init.normal_(self.attnpool.c_proj.weight, std=std) for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) def lock(self, unlocked_groups=0, freeze_bn_stats=False): assert unlocked_groups == 0, 'partial locking not currently supported for this model' for param in self.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): # FIXME support for non-transformer pass def stem(self, x): x = self.act1(self.bn1(self.conv1(x))) x = self.act2(self.bn2(self.conv2(x))) x = self.act3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x def forward(self, x): x = self.stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) return x ================================================ FILE: eva_clip/openai.py ================================================ """ OpenAI pretrained model functions Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
""" import os import warnings from typing import List, Optional, Union import torch from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url __all__ = ["list_openai_models", "load_openai_model"] def list_openai_models() -> List[str]: """Returns the names of available CLIP models""" return list_pretrained_models_by_tag('openai') def load_openai_model( name: str, precision: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, jit: bool = True, cache_dir: Optional[str] = None, ): """Load a CLIP model Parameters ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict precision: str Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model jit : bool Whether to load the optimized JIT model (default) or more hackable non-JIT model. cache_dir : Optional[str] The directory to cache the downloaded model weights Returns ------- model : torch.nn.Module The CLIP model preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if precision is None: precision = 'fp32' if device == 'cpu' else 'fp16' if get_pretrained_url(name, 'openai'): model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) elif os.path.isfile(name): model_path = name else: raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() state_dict = None except RuntimeError: # loading saved state dict if jit: warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") jit = False state_dict = torch.load(model_path, map_location="cpu") if not jit: # Build a non-jit model from the OpenAI jitted model state dict cast_dtype = get_cast_dtype(precision) try: model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) if precision.startswith('amp') or precision == 'fp32': model.float() elif precision == 'bf16': convert_weights_to_lp(model, dtype=torch.bfloat16) return model # patch the device names device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] def patch_device(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("prim::Constant"): if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_device) patch_device(model.encode_image) patch_device(model.encode_text) # patch dtype to float32 (typically for CPU) if precision == 'fp32': float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() def patch_float(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [1, 2]: # dtype can be the second or third argument to aten::to() if inputs[i].node()["value"] == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float) patch_float(model.encode_image) patch_float(model.encode_text) model.float() # ensure image_size attr available at consistent location for both jit and non-jit model.visual.image_size = model.input_resolution.item() return model ================================================ FILE: eva_clip/pretrained.py ================================================ import hashlib import os import urllib import warnings from functools import partial from typing import Dict, Union from tqdm import tqdm try: from huggingface_hub import hf_hub_download _has_hf_hub = True except ImportError: hf_hub_download = None _has_hf_hub = False def _pcfg(url='', hf_hub='', filename='', mean=None, std=None): return dict( url=url, hf_hub=hf_hub, mean=mean, std=std, ) _VITB32 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), laion2b_e16=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') ) _VITB32_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), ) _VITB16 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), ) _EVAB16 = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), ) _VITL14 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ) _EVAL14 = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), ) _VITL14_336 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), ) _EVAL14_336 = dict( eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), ) _VITH14 = dict( laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), ) _VITg14 = dict( laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), ) _EVAg14 = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), ) _EVAg14_PLUS = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), 
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), ) _VITbigG14 = dict( laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), ) _EVAbigE14 = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), ) _EVAbigE14_PLUS = dict( eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), ) _PRETRAINED = { # "ViT-B-32": _VITB32, "OpenaiCLIP-B-32": _VITB32, "OpenCLIP-B-32": _VITB32, # "ViT-B-32-quickgelu": _VITB32_quickgelu, "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, # "ViT-B-16": _VITB16, "OpenaiCLIP-B-16": _VITB16, "OpenCLIP-B-16": _VITB16, "EVA02-B-16": _EVAB16, "EVA02-CLIP-B-16": _EVAB16, # "ViT-B-16-plus-240": _VITB16_PLUS_240, "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, # "ViT-L-14": _VITL14, "OpenaiCLIP-L-14": _VITL14, "OpenCLIP-L-14": _VITL14, "EVA02-L-14": _EVAL14, "EVA02-CLIP-L-14": _EVAL14, # "ViT-L-14-336": _VITL14_336, "OpenaiCLIP-L-14-336": _VITL14_336, "EVA02-CLIP-L-14-336": _EVAL14_336, # "ViT-H-14": _VITH14, # "ViT-g-14": _VITg14, "OpenCLIP-H-14": _VITH14, "OpenCLIP-g-14": _VITg14, "EVA01-CLIP-g-14": _EVAg14, "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, # "ViT-bigG-14": _VITbigG14, "OpenCLIP-bigG-14": _VITbigG14, "EVA02-CLIP-bigE-14": _EVAbigE14, "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, } def _clean_tag(tag: str): # normalize pretrained tags return tag.lower().replace('-', '_') def list_pretrained(as_str: bool = False): """ returns list of pretrained models Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True """ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] def list_pretrained_models_by_tag(tag: str): """ return all models having the specified pretrain tag """ models = [] tag = _clean_tag(tag) for k in _PRETRAINED.keys(): if tag in _PRETRAINED[k]: models.append(k) return models def list_pretrained_tags_by_model(model: str): """ return all pretrain tags for the specified model architecture """ tags = [] if model in _PRETRAINED: tags.extend(_PRETRAINED[model].keys()) return tags def is_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return False return _clean_tag(tag) in _PRETRAINED[model] def get_pretrained_cfg(model: str, tag: str): if model not in _PRETRAINED: return {} model_pretrained = _PRETRAINED[model] return model_pretrained.get(_clean_tag(tag), {}) def get_pretrained_url(model: str, tag: str): cfg = get_pretrained_cfg(model, _clean_tag(tag)) return cfg.get('url', '') def download_pretrained_from_url( url: str, cache_dir: Union[str, None] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") os.makedirs(cache_dir, exist_ok=True) filename = os.path.basename(url) if 'openaipublic' in url: expected_sha256 = url.split("/")[-2] elif 'mlfoundations' in url: expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] else: expected_sha256 = '' download_target = os.path.join(cache_dir, filename) if 
os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): if expected_sha256: if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): return download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") else: return download_target with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") return download_target def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') return _has_hf_hub def download_pretrained_from_hf( model_id: str, filename: str = 'open_clip_pytorch_model.bin', revision=None, cache_dir: Union[str, None] = None, ): has_hf_hub(True) cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) return cached_file def download_pretrained( cfg: Dict, force_hf_hub: bool = False, cache_dir: Union[str, None] = None, ): target = '' if not cfg: return target download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') if download_hf_hub and force_hf_hub: # use HF hub even if url exists download_url = '' if download_url: target = download_pretrained_from_url(download_url, cache_dir=cache_dir) elif download_hf_hub: has_hf_hub(True) # we assume the hf_hub entries in pretrained config combine model_id + filename in # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
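# Worked examples of the convention described above, using hf_hub values that appear in the
# _PRETRAINED tables in this file:
#   os.path.split('QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt')
#     -> ('QuanSun/EVA-CLIP', 'EVA02_CLIP_L_336_psz14_s6B.pt')    # explicit filename
#   os.path.split('laion/CLIP-ViT-H-14-laion2B-s32B-b79K/')
#     -> ('laion/CLIP-ViT-H-14-laion2B-s32B-b79K', '')            # trailing slash: filename is empty,
#        so the 'open_clip_pytorch_model.bin' default of download_pretrained_from_hf is used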
model_id, filename = os.path.split(download_hf_hub) if filename: target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) else: target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) return target ================================================ FILE: eva_clip/rope.py ================================================ from math import pi import torch from torch import nn from einops import rearrange, repeat import logging def broadcat(tensors, dim = -1): num_tensors = len(tensors) shape_lens = set(list(map(lambda t: len(t.shape), tensors))) assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' shape_len = list(shape_lens)[0] dim = (dim + shape_len) if dim < 0 else dim dims = list(zip(*map(lambda t: list(t.shape), tensors))) expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) expanded_dims.insert(dim, (dim, dims[dim])) expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) return torch.cat(tensors, dim = dim) def rotate_half(x): x = rearrange(x, '... (d r) -> ... d r', r = 2) x1, x2 = x.unbind(dim = -1) x = torch.stack((-x2, x1), dim = -1) return rearrange(x, '... d r -> ... (d r)') class VisionRotaryEmbedding(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs = None, freqs_for = 'lang', theta = 10000, max_freq = 10, num_freqs = 1, ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == 'lang': freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi elif freqs_for == 'constant': freqs = torch.ones(num_freqs).float() else: raise ValueError(f'unknown modality {freqs_for}') if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs_h = torch.einsum('..., f -> ... f', t, freqs) freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) freqs_w = torch.einsum('..., f -> ... f', t, freqs) freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) self.register_buffer("freqs_cos", freqs.cos()) self.register_buffer("freqs_sin", freqs.sin()) logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') def forward(self, t, start_index = 0): rot_dim = self.freqs_cos.shape[-1] end_index = start_index + rot_dim assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) return torch.cat((t_left, t, t_right), dim = -1) class VisionRotaryEmbeddingFast(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs = None, freqs_for = 'lang', theta = 10000, max_freq = 10, num_freqs = 1, patch_dropout = 0. ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == 'lang': freqs = 1. 
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi elif freqs_for == 'constant': freqs = torch.ones(num_freqs).float() else: raise ValueError(f'unknown modality {freqs_for}') if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs = torch.einsum('..., f -> ... f', t, freqs) freqs = repeat(freqs, '... n -> ... (n r)', r = 2) freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) self.patch_dropout = patch_dropout self.register_buffer("freqs_cos", freqs_cos) self.register_buffer("freqs_sin", freqs_sin) logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') def forward(self, t, patch_indices_keep=None): if patch_indices_keep is not None: batch = t.size()[0] batch_indices = torch.arange(batch) batch_indices = batch_indices[..., None] freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) freqs_cos = freqs_cos[batch_indices, patch_indices_keep] freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') freqs_sin = freqs_sin[batch_indices, patch_indices_keep] freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') return t * freqs_cos + rotate_half(t) * freqs_sin return t * self.freqs_cos + rotate_half(t) * self.freqs_sin ================================================ FILE: eva_clip/timm_model.py ================================================ """ timm model adapter Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ import logging from collections import OrderedDict import torch import torch.nn as nn try: import timm from timm.models.layers import Mlp, to_2tuple try: # old timm imports < 0.8.1 from timm.models.layers.attention_pool2d import RotAttentionPool2d from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d except ImportError: # new timm imports >= 0.8.1 from timm.layers import RotAttentionPool2d from timm.layers import AttentionPool2d as AbsAttentionPool2d except ImportError: timm = None from .utils import freeze_batch_norm_2d class TimmModel(nn.Module): """ timm model adapter # FIXME this adapter is a work in progress, may change in ways that break weight compat """ def __init__( self, model_name, embed_dim, image_size=224, pool='avg', proj='linear', proj_bias=False, drop=0., pretrained=False): super().__init__() if timm is None: raise RuntimeError("Please `pip install timm` to use timm models.") self.image_size = to_2tuple(image_size) self.trunk = timm.create_model(model_name, pretrained=pretrained) feat_size = self.trunk.default_cfg.get('pool_size', None) feature_ndim = 1 if not feat_size else 2 if pool in ('abs_attn', 'rot_attn'): assert feature_ndim == 2 # if attn pooling used, remove both classifier and default pool self.trunk.reset_classifier(0, global_pool='') else: # reset global pool if pool config set, otherwise leave as network default reset_kwargs = dict(global_pool=pool) if pool else {} self.trunk.reset_classifier(0, **reset_kwargs) prev_chs = self.trunk.num_features head_layers = OrderedDict() if pool == 'abs_attn': head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim elif pool == 'rot_attn': head_layers['pool'] = RotAttentionPool2d(prev_chs, 
out_features=embed_dim) prev_chs = embed_dim else: assert proj, 'projection layer needed if non-attention pooling is used.' # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used if proj == 'linear': head_layers['drop'] = nn.Dropout(drop) head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) self.head = nn.Sequential(head_layers) def lock(self, unlocked_groups=0, freeze_bn_stats=False): """ lock modules Args: unlocked_groups (int): leave last n layer groups unlocked (default: 0) """ if not unlocked_groups: # lock full model for param in self.trunk.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self.trunk) else: # NOTE: partial freeze requires latest timm (master) branch and is subject to change try: # FIXME import here until API stable and in an official release from timm.models.helpers import group_parameters, group_modules except ImportError: raise RuntimeError( 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') matcher = self.trunk.group_matcher() gparams = group_parameters(self.trunk, matcher) max_layer_id = max(gparams.keys()) max_layer_id = max_layer_id - unlocked_groups for group_idx in range(max_layer_id + 1): group = gparams[group_idx] for param in group: self.trunk.get_parameter(param).requires_grad = False if freeze_bn_stats: gmodules = group_modules(self.trunk, matcher, reverse=True) gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} freeze_batch_norm_2d(self.trunk, gmodules) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): try: self.trunk.set_grad_checkpointing(enable) except Exception as e: logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') def forward(self, x): x = self.trunk(x) x = self.head(x) return x ================================================ FILE: eva_clip/tokenizer.py ================================================ """ CLIP tokenizer Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import gzip import html import os from functools import lru_cache from typing import Union, List import ftfy import regex as re import torch # https://stackoverflow.com/q/62691279 import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @lru_cache() def default_bpe(): return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") @lru_cache() def bytes_to_unicode(): """ Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8+n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) def get_pairs(word): """Return set of symbol pairs in a word. 
Word is represented as tuple of symbols (symbols being variable-length strings). """ pairs = set() prev_char = word[0] for char in word[1:]: pairs.add((prev_char, char)) prev_char = char return pairs def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r'\s+', ' ', text) text = text.strip() return text class SimpleTokenizer(object): def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') merges = merges[1:49152-256-2+1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) vocab = vocab + [v+'' for v in vocab] for merge in merges: vocab.append(''.join(merge)) if not special_tokens: special_tokens = ['', ''] else: special_tokens = ['', ''] + special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {t:t for t in special_tokens} special = "|".join(special_tokens) self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) self.vocab_size = len(self.encoder) self.all_special_ids = [self.encoder[t] for t in special_tokens] def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token[:-1]) + ( token[-1] + '',) pairs = get_pairs(word) if not pairs: return token+'' while True: bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except: new_word.extend(word[i:]) break if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = ' '.join(word) self.cache[token] = word return word def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text _tokenizer = SimpleTokenizer() def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) Parameters ---------- texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ if isinstance(texts, str): texts = [texts] sot_token = _tokenizer.encoder[""] eot_token = _tokenizer.encoder[""] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in 
enumerate(all_tokens): if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = eot_token result[i, :len(tokens)] = torch.tensor(tokens) return result class HFTokenizer: "HuggingFace tokenizer wrapper" def __init__(self, tokenizer_name:str): from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] texts = [whitespace_clean(basic_clean(text)) for text in texts] input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids return input_ids ================================================ FILE: eva_clip/transform.py ================================================ from typing import Optional, Sequence, Tuple import torch import torch.nn as nn import torchvision.transforms.functional as F from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD class ResizeMaxSize(nn.Module): def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): super().__init__() if not isinstance(max_size, int): raise TypeError(f"Size should be int. Got {type(max_size)}") self.max_size = max_size self.interpolation = interpolation self.fn = min if fn == 'min' else min self.fill = fill def forward(self, img): if isinstance(img, torch.Tensor): height, width = img.shape[:2] else: width, height = img.size scale = self.max_size / float(max(height, width)) if scale != 1.0: new_size = tuple(round(dim * scale) for dim in (height, width)) img = F.resize(img, new_size, self.interpolation) pad_h = self.max_size - new_size[0] pad_w = self.max_size - new_size[1] img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) return img def _convert_to_rgb(image): return image.convert('RGB') # class CatGen(nn.Module): # def __init__(self, num=4): # self.num = num # def mixgen_batch(image, text): # batch_size = image.shape[0] # index = np.random.permutation(batch_size) # cat_images = [] # for i in range(batch_size): # # image mixup # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] # # text concat # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] # text = torch.stack(text) # return image, text def image_transform( image_size: int, is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, resize_longest_max: bool = False, fill_color: int = 0, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 std = std or OPENAI_DATASET_STD if not isinstance(std, (list, tuple)): std = (std,) * 3 if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: # for square size, pass size as int so that Resize() uses aspect preserving shortest edge image_size = image_size[0] normalize = Normalize(mean=mean, std=std) if is_train: return Compose([ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), _convert_to_rgb, ToTensor(), normalize, ]) else: if resize_longest_max: transforms = [ ResizeMaxSize(image_size, fill=fill_color) ] else: transforms = [ Resize(image_size, 
interpolation=InterpolationMode.BICUBIC), CenterCrop(image_size), ] transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) return Compose(transforms) ================================================ FILE: eva_clip/transformer.py ================================================ import os import logging from collections import OrderedDict import math from typing import Callable, Optional, Sequence import numpy as np import torch from torch import nn from torch.nn import functional as F try: from timm.models.layers import trunc_normal_ except: from timm.layers import trunc_normal_ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast from .utils import to_2tuple if os.getenv('ENV_TYPE') == 'deepspeed': try: import deepspeed from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint except: print("Please 'pip install deepspeed'") deepspeed = None from torch.utils.checkpoint import checkpoint else: from torch.utils.checkpoint import checkpoint try: import xformers.ops as xops except ImportError: xops = None print("Please 'pip install xformers'") class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: torch.Tensor): output = F.layer_norm( x.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps, ) return output.type_as(x) class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm (with cast back to input dtype).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class QuickGELU(nn.Module): # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 """ def __init__(self, prob, exclude_first_token=True): super().__init__() assert 0 <= prob < 1. 
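# PatchDropout keeps only a (1 - prob) fraction of patch tokens during training
# (https://arxiv.org/abs/2212.00794); with exclude_first_token=True the CLS token is
# always retained and re-attached after the random selection.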
self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") def forward(self, x): if not self.training or self.prob == 0.: return x if self.exclude_first_token: cls_tokens, x = x[:, :1], x[:, 1:] else: cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) batch = x.size()[0] num_tokens = x.size()[1] batch_indices = torch.arange(batch) batch_indices = batch_indices[..., None] keep_prob = 1 - self.prob num_patches_keep = max(1, int(num_tokens * keep_prob)) rand = torch.randn(batch, num_tokens) patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices x = x[batch_indices, patch_indices_keep] if self.exclude_first_token: x = torch.cat((cls_tokens, x), dim=1) if self.training and os.getenv('RoPE') == '1': return x, patch_indices_keep return x def _in_projection_packed( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None, ): """ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726 """ E = q.size(-1) if k is v: if q is k: # self-attention return F.linear(q, w, b).chunk(3, dim=-1) else: # encoder-decoder attention w_q, w_kv = w.split([E, E * 2]) if b is None: b_q = b_kv = None else: b_q, b_kv = b.split([E, E * 2]) return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) else: w_q, w_k, w_v = w.chunk(3) if b is None: b_q = b_k = b_v = None else: b_q, b_k, b_v = b.chunk(3) return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1. / 0.01), attn_drop=0., proj_drop=0., xattn=False, rope=False ): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) if qkv_bias: self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) else: self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) else: self.logit_scale = None self.attn_drop = nn.Dropout(attn_drop) if self.scale_heads: self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) else: self.head_scale = None self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) self.xattn = xattn self.xattn_drop = attn_drop self.rope = rope def forward(self, x, attn_mask: Optional[torch.Tensor] = None): L, N, C = x.shape q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) if self.xattn: q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) x = xops.memory_efficient_attention( q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, ) else: q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale 
is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(N, self.num_heads, L, L) * logit_scale attn = attn.view(-1, L, L) else: q = q * self.scale attn = torch.bmm(q, k.transpose(-1, -2)) if attn_mask is not None: if attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = new_attn_mask attn += attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(N, self.num_heads, L, C) * self.head_scale x = x.view(-1, L, C) x = x.transpose(0, 1).reshape(L, N, C) x = self.out_proj(x) x = self.out_drop(x) return x class CustomAttention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1. / 0.01), attn_drop=0., proj_drop=0., xattn=False ): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) if qkv_bias: self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) else: self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) else: self.logit_scale = None self.attn_drop = nn.Dropout(attn_drop) if self.scale_heads: self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) else: self.head_scale = None self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) self.xattn = xattn self.xattn_drop = attn_drop def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) N_q, B_q, C_q = q.shape N_k, B_k, C_k = k.shape N_v, B_v, C_v = v.shape if self.xattn: # B, N, C -> B, N, num_heads, C q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1) k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1) v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1) x = xops.memory_efficient_attention( q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None ) else: # B*H, L, C q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: # B*H, N_q, N_k attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale attn = attn.view(-1, N_q, N_k) else: q = q * self.scale attn = torch.bmm(q, k.transpose(-1, -2)) if attn_mask is not None: if attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = new_attn_mask attn += attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) 
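# attention-weighted sum over the values, optional per-head rescaling, then the heads
# are folded back to (N_q, B_q, C_q) before the output projection and dropout.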
x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale x = x.view(-1, N_q, C_q) x = x.transpose(0, 1).reshape(N_q, B_q, C_q) x = self.out_proj(x) x = self.out_drop(x) return x class CustomResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, scale_cosine_attn: bool = False, scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, cross_attn: bool = False, xattn: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1 self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1 self.attn = CustomAttention( d_model, n_head, qkv_bias=True, attn_drop=0., proj_drop=0., scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn ) self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask))) q = q + self.ls_2(self.mlp(self.ln_2(q))) return q class CustomTransformer(nn.Module): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, scale_cosine_attn: bool = True, scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, cross_attn: bool = False, xattn: bool = False, ): super().__init__() self.width = width self.layers = layers self.grad_checkpointing = False self.xattn = xattn self.resblocks = nn.ModuleList([ CustomResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, scale_cosine_attn=scale_cosine_attn, scale_heads=scale_heads, scale_attn=scale_attn, scale_fc=scale_fc, cross_attn=cross_attn, xattn=xattn) for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None): if k is None and v is None: k = v = q for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): q = checkpoint(r, q, k, v, attn_mask) else: q = r(q, k, v, attn_mask=attn_mask) return q class ResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, xattn: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) if xattn: self.attn = Attention(d_model, n_head, xattn=True) else: self.attn = nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) 
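# pre-norm feed-forward block: expand to mlp_width, apply the activation, project back to d_model.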
self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.xattn = xattn def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None if self.xattn: return self.attn(x, attn_mask=attn_mask) return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x class Transformer(nn.Module): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, xattn: bool = False, ): super().__init__() self.width = width self.layers = layers self.grad_checkpointing = False self.resblocks = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(r, x, attn_mask) else: x = r(x, attn_mask=attn_mask) return x class VisionTransformer(nn.Module): def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, ls_init_value: float = None, patch_dropout: float = 0., global_average_pool: bool = False, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, xattn: bool = False, ): super().__init__() self.image_size = to_2tuple(image_size) self.patch_size = to_2tuple(patch_size) self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) self.output_dim = output_dim self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() self.ln_pre = norm_layer(width) self.transformer = Transformer( width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn ) self.global_average_pool = global_average_pool self.ln_post = norm_layer(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) def lock(self, unlocked_groups=0, freeze_bn_stats=False): for param in self.parameters(): param.requires_grad = False if unlocked_groups != 0: groups = [ [ self.conv1, self.class_embedding, self.positional_embedding, self.ln_pre, ], *self.transformer.resblocks[:-1], [ self.transformer.resblocks[-1], self.ln_post, ], self.proj, ] def _unlock(x): if isinstance(x, Sequence): for g in x: _unlock(g) else: if isinstance(x, torch.nn.Parameter): x.requires_grad = True else: for p in x.parameters(): p.requires_grad = True _unlock(groups[-unlocked_groups:]) def get_num_layers(self): return self.transformer.layers @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): return {'positional_embedding', 'class_embedding'} def forward(self, x: torch.Tensor, return_all_features: bool=False): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat( [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in x = self.patch_dropout(x) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD if not return_all_features: if self.global_average_pool: x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1) else: x = x[:, 0] x = self.ln_post(x) if self.proj is not None: x = x @ self.proj return x class TextTransformer(nn.Module): def __init__( self, context_length: int = 77, vocab_size: int = 49408, width: int = 512, heads: int = 8, layers: int = 12, ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, xattn: bool= False, attn_mask: bool = True ): super().__init__() self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) self.transformer = Transformer( width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn ) self.xattn = xattn self.ln_final = norm_layer(width) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) if attn_mask: self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) else: self.attn_mask = None self.init_parameters() def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, 
std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): # return {'positional_embedding', 'token_embedding'} return {'positional_embedding'} def get_num_layers(self): return self.transformer.layers def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def forward(self, text, return_all_features: bool=False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) # x = self.transformer(x) # no attention mask is applied x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) if not return_all_features: # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x ================================================ FILE: eva_clip/utils.py ================================================ from itertools import repeat import collections.abc import logging import math import numpy as np import torch from torch import nn as nn from torchvision.ops.misc import FrozenBatchNorm2d import torch.nn.functional as F # open CLIP def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): # Rescale the grid of position embeddings when loading from state_dict old_pos_embed = state_dict.get('visual.positional_embedding', None) if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, align_corners=True, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img state_dict['visual.positional_embedding'] = new_pos_embed def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): # Rescale the grid of position embeddings when loading from state_dict old_pos_embed = state_dict.get('positional_embedding', None) if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) 
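# same grid interpolation as resize_clip_pos_embed above, but for checkpoints that store the
# embedding under the 'positional_embedding' key (no 'visual.' prefix).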
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, align_corners=True, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img state_dict['positional_embedding'] = new_pos_embed def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): all_keys = list(state_dict.keys()) # interpolate position embedding if 'visual.pos_embed' in state_dict: pos_embed_checkpoint = state_dict['visual.pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.visual.patch_embed.num_patches num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['visual.pos_embed'] = new_pos_embed patch_embed_proj = state_dict['visual.patch_embed.proj.weight'] patch_size = model.visual.patch_embed.patch_size state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate( patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): all_keys = list(state_dict.keys()) # interpolate position embedding if 'pos_embed' in state_dict: pos_embed_checkpoint = state_dict['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.visual.patch_embed.num_patches num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = 
pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['pos_embed'] = new_pos_embed patch_embed_proj = state_dict['patch_embed.proj.weight'] patch_size = model.visual.patch_embed.patch_size state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): all_keys = list(state_dict.keys()) for key in all_keys: if "relative_position_index" in key: state_dict.pop(key) if "relative_position_bias_table" in key: rel_pos_bias = state_dict[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = model.visual.state_dict()[key].size() dst_patch_shape = model.visual.patch_embed.patch_shape if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) src_size = int((src_num_pos - num_extra_tokens) ** 0.5) dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: print("Position interpolate for %s from %dx%d to %dx%d" % ( key, src_size, src_size, dst_size, dst_size)) extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) print("Original positions = %s" % str(x)) print("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(num_attn_heads): z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() f = F.interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append( torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) state_dict[key] = new_rel_pos_bias # interpolate position embedding if 'pos_embed' in state_dict: pos_embed_checkpoint = state_dict['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.visual.patch_embed.num_patches num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = 
torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) state_dict['pos_embed'] = new_pos_embed patch_embed_proj = state_dict['patch_embed.proj.weight'] patch_size = model.visual.patch_embed.patch_size state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) def freeze_batch_norm_2d(module, module_match={}, name=''): """ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and returned. Otherwise, the module is walked recursively and submodules are converted in place. Args: module (torch.nn.Module): Any PyTorch module. module_match (dict): Dictionary of full module names to freeze (all if empty) name (str): Full module name (prefix) Returns: torch.nn.Module: Resulting module Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module is_match = True if module_match: is_match = name in module_match if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): res = FrozenBatchNorm2d(module.num_features) res.num_features = module.num_features res.affine = module.affine if module.affine: res.weight.data = module.weight.data.clone().detach() res.bias.data = module.bias.data.clone().detach() res.running_mean.data = module.running_mean.data res.running_var.data = module.running_var.data res.eps = module.eps else: for child_name, child in module.named_children(): full_child_name = '.'.join([name, child_name]) if name else child_name new_child = freeze_batch_norm_2d(child, module_match, full_child_name) if new_child is not child: res.add_module(child_name, new_child) return res # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = lambda n, x: _ntuple(n)(x) def is_logging(args): def is_global_master(args): return args.rank == 0 def is_local_master(args): return args.local_rank == 0 def is_master(args, local=False): return is_local_master(args) if local else is_global_master(args) return is_master class AllGather(torch.autograd.Function): """An autograd function that performs allgather on a tensor. Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. 
""" @staticmethod def forward(ctx, tensor, rank, world_size): tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] torch.distributed.all_gather(tensors_gather, tensor) ctx.rank = rank ctx.batch_size = tensor.shape[0] return torch.cat(tensors_gather, 0) @staticmethod def backward(ctx, grad_output): return ( grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], None, None ) allgather = AllGather.apply ================================================ FILE: examples/flux_pulid_multi.json ================================================ { "last_node_id": 66, "last_link_id": 133, "nodes": [ { "id": 16, "type": "KSamplerSelect", "pos": { "0": 384, "1": 313 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [ { "name": "SAMPLER", "type": "SAMPLER", "links": [ 85 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "KSamplerSelect" }, "widgets_values": [ "euler" ] }, { "id": 10, "type": "VAELoader", "pos": { "0": 12, "1": 285 }, "size": { "0": 311.81634521484375, "1": 60.429901123046875 }, "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 88 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAELoader" }, "widgets_values": [ "ae.sft" ] }, { "id": 27, "type": "EmptySD3LatentImage", "pos": { "0": 383, "1": 155 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [ { "name": "LATENT", "type": "LATENT", "links": [ 86 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "EmptySD3LatentImage" }, "widgets_values": [ 896, 1152, 1 ] }, { "id": 25, "type": "RandomNoise", "pos": { "0": 6, "1": -135 }, "size": { "0": 315, "1": 82 }, "flags": {}, "order": 3, "mode": 0, "inputs": [], "outputs": [ { "name": "NOISE", "type": "NOISE", "links": [ 84 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "RandomNoise" }, "widgets_values": [ 641817409332707, "randomize" ] }, { "id": 47, "type": "BasicGuider", "pos": { "0": 1088, "1": 366 }, "size": { "0": 241.79998779296875, "1": 46 }, "flags": {}, "order": 15, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 122 }, { "name": "conditioning", "type": "CONDITIONING", "link": 107 } ], "outputs": [ { "name": "GUIDER", "type": "GUIDER", "links": [ 83 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicGuider" } }, { "id": 49, "type": "VAEDecode", "pos": { "0": 1168, "1": -111 }, "size": { "0": 210, "1": 46 }, "flags": {}, "order": 17, "mode": 0, "inputs": [ { "name": "samples", "type": "LATENT", "link": 87 }, { "name": "vae", "type": "VAE", "link": 88 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 89 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAEDecode" } }, { "id": 50, "type": "PreviewImage", "pos": { "0": 1502, "1": -451 }, "size": { "0": 1079.977783203125, "1": 1041.9154052734375 }, "flags": {}, "order": 18, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", "link": 89 } ], "outputs": [], "properties": { "Node name for S&R": "PreviewImage" } }, { "id": 63, "type": "UNETLoader", "pos": { "0": 6, "1": -7 }, "size": { "0": 315, "1": 82 }, "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 130, 131 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "UNETLoader" }, "widgets_values": [ "flux1-dev.safetensors", "default" ] }, { "id": 17, 
"type": "BasicScheduler", "pos": { "0": 392, "1": 424 }, "size": { "0": 315, "1": 106 }, "flags": { "collapsed": false }, "order": 11, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 131, "slot_index": 0 } ], "outputs": [ { "name": "SIGMAS", "type": "SIGMAS", "links": [ 93 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicScheduler" }, "widgets_values": [ "simple", 20, 1 ] }, { "id": 6, "type": "CLIPTextEncode", "pos": { "0": 369, "1": -63 }, "size": { "0": 422.84503173828125, "1": 164.31304931640625 }, "flags": {}, "order": 12, "mode": 0, "inputs": [ { "name": "clip", "type": "CLIP", "link": 132 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 41 ], "slot_index": 0 } ], "title": "CLIP Text Encode (Positive Prompt)", "properties": { "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ "portrait, color, cinematic" ] }, { "id": 48, "type": "SamplerCustomAdvanced", "pos": { "0": 1128, "1": -12 }, "size": { "0": 355.20001220703125, "1": 326 }, "flags": {}, "order": 16, "mode": 0, "inputs": [ { "name": "noise", "type": "NOISE", "link": 84 }, { "name": "guider", "type": "GUIDER", "link": 83 }, { "name": "sampler", "type": "SAMPLER", "link": 85 }, { "name": "sigmas", "type": "SIGMAS", "link": 93 }, { "name": "latent_image", "type": "LATENT", "link": 86 } ], "outputs": [ { "name": "output", "type": "LATENT", "links": [ 87 ], "slot_index": 0, "shape": 3 }, { "name": "denoised_output", "type": "LATENT", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "SamplerCustomAdvanced" } }, { "id": 53, "type": "PulidFluxInsightFaceLoader", "pos": { "0": 799, "1": -172 }, "size": { "0": 365.4000244140625, "1": 58 }, "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [ { "name": "FACEANALYSIS", "type": "FACEANALYSIS", "links": [ 124 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxInsightFaceLoader" }, "widgets_values": [ "CPU" ] }, { "id": 26, "type": "FluxGuidance", "pos": { "0": 372, "1": -171 }, "size": { "0": 317.4000244140625, "1": 58 }, "flags": { "collapsed": false }, "order": 14, "mode": 0, "inputs": [ { "name": "conditioning", "type": "CONDITIONING", "link": 41 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 107 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "FluxGuidance" }, "widgets_values": [ 4 ] }, { "id": 64, "type": "DualCLIPLoader", "pos": { "0": 8, "1": 124 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 6, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP", "type": "CLIP", "links": [ 132 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "DualCLIPLoader" }, "widgets_values": [ "t5xxl_fp8_e4m3fn.safetensors", "clip_l.safetensors", "flux" ] }, { "id": 66, "type": "LoadImagesFromDir //Inspire", "pos": { "0": 14, "1": 623 }, "size": { "0": 567, "1": 170 }, "flags": {}, "order": 7, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 133 ], "shape": 3, "slot_index": 0 }, { "name": "MASK", "type": "MASK", "links": null, "shape": 3 }, { "name": "INT", "type": "INT", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "LoadImagesFromDir //Inspire" }, "widgets_values": [ "", 0, 0, false ] }, { "id": 45, "type": "PulidFluxModelLoader", "pos": { "0": 788, "1": 42 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 8, "mode": 0, "inputs": [], "outputs": [ { "name": "PULIDFLUX", "type": "PULIDFLUX", "links": [ 125 ], 
"slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxModelLoader" }, "widgets_values": [ "pulid_flux_v0.9.0.safetensors" ] }, { "id": 51, "type": "PulidFluxEvaClipLoader", "pos": { "0": 799, "1": -60 }, "size": { "0": 327.5999755859375, "1": 26 }, "flags": {}, "order": 9, "mode": 0, "inputs": [], "outputs": [ { "name": "EVA_CLIP", "type": "EVA_CLIP", "links": [ 123 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxEvaClipLoader" } }, { "id": 65, "type": "Note", "pos": { "0": 797, "1": 565 }, "size": { "0": 278.80340576171875, "1": 167.5153045654297 }, "flags": {}, "order": 10, "mode": 0, "inputs": [], "outputs": [], "properties": {}, "widgets_values": [ "fusion_weight_max and min only works when choose auto_weight.\n\ntrain_step only works when choose train_weight" ] }, { "id": 62, "type": "ApplyPulidFlux", "pos": { "0": 740, "1": 174 }, "size": { "0": 315, "1": 326 }, "flags": {}, "order": 13, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 130 }, { "name": "pulid_flux", "type": "PULIDFLUX", "link": 125 }, { "name": "eva_clip", "type": "EVA_CLIP", "link": 123 }, { "name": "face_analysis", "type": "FACEANALYSIS", "link": 124 }, { "name": "image", "type": "IMAGE", "link": 133 }, { "name": "attn_mask", "type": "MASK", "link": null } ], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 122 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "ApplyPulidFlux" }, "widgets_values": [ 1, 0, 1, "mean", 1, 0, 1000, true ] } ], "links": [ [ 41, 6, 0, 26, 0, "CONDITIONING" ], [ 83, 47, 0, 48, 1, "GUIDER" ], [ 84, 25, 0, 48, 0, "NOISE" ], [ 85, 16, 0, 48, 2, "SAMPLER" ], [ 86, 27, 0, 48, 4, "LATENT" ], [ 87, 48, 0, 49, 0, "LATENT" ], [ 88, 10, 0, 49, 1, "VAE" ], [ 89, 49, 0, 50, 0, "IMAGE" ], [ 93, 17, 0, 48, 3, "SIGMAS" ], [ 107, 26, 0, 47, 1, "CONDITIONING" ], [ 122, 62, 0, 47, 0, "MODEL" ], [ 123, 51, 0, 62, 2, "EVA_CLIP" ], [ 124, 53, 0, 62, 3, "FACEANALYSIS" ], [ 125, 45, 0, 62, 1, "PULIDFLUX" ], [ 130, 63, 0, 62, 0, "MODEL" ], [ 131, 63, 0, 17, 0, "MODEL" ], [ 132, 64, 0, 6, 0, "CLIP" ], [ 133, 66, 0, 62, 4, "IMAGE" ] ], "groups": [], "config": {}, "extra": { "ds": { "scale": 0.6830134553650705, "offset": [ 237.9025120377926, 565.1585643260208 ] } }, "version": 0.4 } ================================================ FILE: examples/pulid_flux_16bit_simple.json ================================================ { "last_node_id": 64, "last_link_id": 132, "nodes": [ { "id": 25, "type": "RandomNoise", "pos": { "0": 6, "1": -135 }, "size": { "0": 315, "1": 82 }, "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [ { "name": "NOISE", "type": "NOISE", "links": [ 84 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "RandomNoise" }, "widgets_values": [ 186462208016243, "fixed" ], "color": "#2a363b", "bgcolor": "#3f5159" }, { "id": 26, "type": "FluxGuidance", "pos": { "0": 372, "1": -171 }, "size": { "0": 317.4000244140625, "1": 58 }, "flags": { "collapsed": false }, "order": 13, "mode": 0, "inputs": [ { "name": "conditioning", "type": "CONDITIONING", "link": 41 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 107 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "FluxGuidance" }, "widgets_values": [ 3.5 ], "color": "#233", "bgcolor": "#355" }, { "id": 6, "type": "CLIPTextEncode", "pos": { "0": 372, "1": -55 }, "size": { "0": 422.84503173828125, "1": 164.31304931640625 }, "flags": {}, "order": 12, "mode": 0, "inputs": 
[ { "name": "clip", "type": "CLIP", "link": 132 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 41 ], "slot_index": 0 } ], "title": "CLIP Text Encode (Positive Prompt)", "properties": { "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ "Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\"PuLID Flux\", detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard." ], "color": "#232", "bgcolor": "#353" }, { "id": 27, "type": "EmptySD3LatentImage", "pos": { "0": 383, "1": 155 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [ { "name": "LATENT", "type": "LATENT", "links": [ 86 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "EmptySD3LatentImage" }, "widgets_values": [ 768, 1024, 1 ], "color": "#323", "bgcolor": "#535" }, { "id": 16, "type": "KSamplerSelect", "pos": { "0": 384, "1": 313 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [ { "name": "SAMPLER", "type": "SAMPLER", "links": [ 85 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "KSamplerSelect" }, "widgets_values": [ "euler" ] }, { "id": 17, "type": "BasicScheduler", "pos": { "0": 392, "1": 424 }, "size": { "0": 315, "1": 106 }, "flags": { "collapsed": false }, "order": 11, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 131, "slot_index": 0 } ], "outputs": [ { "name": "SIGMAS", "type": "SIGMAS", "links": [ 93 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicScheduler" }, "widgets_values": [ "simple", 10, 1 ] }, { "id": 54, "type": "LoadImage", "pos": { "0": 729, "1": -490 }, "size": { "0": 315, "1": 314 }, "flags": {}, "order": 3, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 126 ], "slot_index": 0, "shape": 3 }, { "name": "MASK", "type": "MASK", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "LoadImage" }, "widgets_values": [ "einstein.jpg", "image" ] }, { "id": 53, "type": "PulidFluxInsightFaceLoader", "pos": { "0": 822, "1": -80 }, "size": { "0": 365.4000244140625, "1": 58 }, "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [ { "name": "FACEANALYSIS", "type": "FACEANALYSIS", "links": [ 124 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxInsightFaceLoader" }, "widgets_values": [ "CPU" ] }, { "id": 51, "type": "PulidFluxEvaClipLoader", "pos": { "0": 845, "1": 52 }, "size": { "0": 327.5999755859375, "1": 26 }, "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [ { "name": "EVA_CLIP", "type": "EVA_CLIP", "links": [ 123 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxEvaClipLoader" } }, { "id": 45, "type": "PulidFluxModelLoader", "pos": { "0": 846, "1": 137 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 6, "mode": 0, "inputs": [], "outputs": [ { "name": "PULIDFLUX", "type": "PULIDFLUX", "links": [ 125 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxModelLoader" }, "widgets_values": [ "pulid_flux_v0.9.0.safetensors" ] }, { "id": 62, "type": "ApplyPulidFlux", "pos": { "0": 842, "1": 258 }, "size": { "0": 315, "1": 206 }, "flags": {}, "order": 10, "mode": 0, "inputs": 
[ { "name": "model", "type": "MODEL", "link": 130 }, { "name": "pulid_flux", "type": "PULIDFLUX", "link": 125 }, { "name": "eva_clip", "type": "EVA_CLIP", "link": 123 }, { "name": "face_analysis", "type": "FACEANALYSIS", "link": 124 }, { "name": "image", "type": "IMAGE", "link": 126 }, { "name": "attn_mask", "type": "MASK", "link": null } ], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 122 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "ApplyPulidFlux" }, "widgets_values": [ 1, 0, 1 ] }, { "id": 47, "type": "BasicGuider", "pos": { "0": 1217, "1": 401 }, "size": { "0": 241.79998779296875, "1": 46 }, "flags": {}, "order": 14, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 122 }, { "name": "conditioning", "type": "CONDITIONING", "link": 107 } ], "outputs": [ { "name": "GUIDER", "type": "GUIDER", "links": [ 83 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicGuider" } }, { "id": 48, "type": "SamplerCustomAdvanced", "pos": { "0": 1205, "1": -39 }, "size": { "0": 355.20001220703125, "1": 326 }, "flags": {}, "order": 15, "mode": 0, "inputs": [ { "name": "noise", "type": "NOISE", "link": 84 }, { "name": "guider", "type": "GUIDER", "link": 83 }, { "name": "sampler", "type": "SAMPLER", "link": 85 }, { "name": "sigmas", "type": "SIGMAS", "link": 93 }, { "name": "latent_image", "type": "LATENT", "link": 86 } ], "outputs": [ { "name": "output", "type": "LATENT", "links": [ 87 ], "slot_index": 0, "shape": 3 }, { "name": "denoised_output", "type": "LATENT", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "SamplerCustomAdvanced" } }, { "id": 49, "type": "VAEDecode", "pos": { "0": 1263, "1": -137 }, "size": { "0": 210, "1": 46 }, "flags": {}, "order": 16, "mode": 0, "inputs": [ { "name": "samples", "type": "LATENT", "link": 87 }, { "name": "vae", "type": "VAE", "link": 88 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 89 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAEDecode" } }, { "id": 50, "type": "PreviewImage", "pos": { "0": 1587, "1": -169 }, "size": { "0": 841.524169921875, "1": 698.3060302734375 }, "flags": {}, "order": 17, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", "link": 89 } ], "outputs": [], "properties": { "Node name for S&R": "PreviewImage" } }, { "id": 63, "type": "UNETLoader", "pos": { "0": 6, "1": -7 }, "size": { "0": 315, "1": 82 }, "flags": {}, "order": 7, "mode": 0, "inputs": [], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 130, 131 ], "shape": 3, "slot_index": 0 } ], "properties": { "Node name for S&R": "UNETLoader" }, "widgets_values": [ "flux1-dev.safetensors", "default" ] }, { "id": 10, "type": "VAELoader", "pos": { "0": 12, "1": 285 }, "size": { "0": 311.81634521484375, "1": 60.429901123046875 }, "flags": {}, "order": 8, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 88 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAELoader" }, "widgets_values": [ "flux1_vae.safetensors" ] }, { "id": 64, "type": "DualCLIPLoader", "pos": { "0": 8, "1": 124 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 9, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP", "type": "CLIP", "links": [ 132 ], "shape": 3, "slot_index": 0 } ], "properties": { "Node name for S&R": "DualCLIPLoader" }, "widgets_values": [ "t5xxl_fp16.safetensors", "clip_l.safetensors", "flux" ] } ], "links": [ [ 41, 6, 0, 26, 0, "CONDITIONING" ], [ 83, 47, 0, 48, 1, "GUIDER" ], 
[ 84, 25, 0, 48, 0, "NOISE" ], [ 85, 16, 0, 48, 2, "SAMPLER" ], [ 86, 27, 0, 48, 4, "LATENT" ], [ 87, 48, 0, 49, 0, "LATENT" ], [ 88, 10, 0, 49, 1, "VAE" ], [ 89, 49, 0, 50, 0, "IMAGE" ], [ 93, 17, 0, 48, 3, "SIGMAS" ], [ 107, 26, 0, 47, 1, "CONDITIONING" ], [ 122, 62, 0, 47, 0, "MODEL" ], [ 123, 51, 0, 62, 2, "EVA_CLIP" ], [ 124, 53, 0, 62, 3, "FACEANALYSIS" ], [ 125, 45, 0, 62, 1, "PULIDFLUX" ], [ 126, 54, 0, 62, 4, "IMAGE" ], [ 130, 63, 0, 62, 0, "MODEL" ], [ 131, 63, 0, 17, 0, "MODEL" ], [ 132, 64, 0, 6, 0, "CLIP" ] ], "groups": [], "config": {}, "extra": { "ds": { "scale": 0.9090909090909091, "offset": [ 113.84966682267732, 547.8597243753773 ] } }, "version": 0.4 } ================================================ FILE: examples/pulid_flux_8bitgguf_simple.json ================================================ { "last_node_id": 62, "last_link_id": 129, "nodes": [ { "id": 25, "type": "RandomNoise", "pos": { "0": 6, "1": -135 }, "size": [ 315, 82 ], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [ { "name": "NOISE", "type": "NOISE", "links": [ 84 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "RandomNoise" }, "widgets_values": [ 186462208016243, "fixed" ], "color": "#2a363b", "bgcolor": "#3f5159" }, { "id": 31, "type": "UnetLoaderGGUF", "pos": { "0": 14, "1": 5 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 127, 129 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "UnetLoaderGGUF" }, "widgets_values": [ "flux1-dev-Q8_0.gguf" ] }, { "id": 41, "type": "DualCLIPLoaderGGUF", "pos": { "0": 18, "1": 114 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP", "type": "CLIP", "links": [ 128 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "DualCLIPLoaderGGUF" }, "widgets_values": [ "t5-v1_1-xxl-encoder-Q8_0.gguf", "clip_l.safetensors", "flux" ] }, { "id": 10, "type": "VAELoader", "pos": { "0": 23, "1": 275 }, "size": { "0": 311.81634521484375, "1": 60.429901123046875 }, "flags": {}, "order": 3, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 88 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAELoader" }, "widgets_values": [ "flux1_vae.safetensors" ] }, { "id": 26, "type": "FluxGuidance", "pos": { "0": 372, "1": -171 }, "size": { "0": 317.4000244140625, "1": 58 }, "flags": { "collapsed": false }, "order": 13, "mode": 0, "inputs": [ { "name": "conditioning", "type": "CONDITIONING", "link": 41 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 107 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "FluxGuidance" }, "widgets_values": [ 3.5 ], "color": "#233", "bgcolor": "#355" }, { "id": 6, "type": "CLIPTextEncode", "pos": { "0": 372, "1": -55 }, "size": { "0": 422.84503173828125, "1": 164.31304931640625 }, "flags": {}, "order": 11, "mode": 0, "inputs": [ { "name": "clip", "type": "CLIP", "link": 128 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 41 ], "slot_index": 0 } ], "title": "CLIP Text Encode (Positive Prompt)", "properties": { "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ "Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\"PuLID Flux\", 
detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard." ], "color": "#232", "bgcolor": "#353" }, { "id": 27, "type": "EmptySD3LatentImage", "pos": { "0": 383, "1": 155 }, "size": { "0": 315, "1": 106 }, "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [ { "name": "LATENT", "type": "LATENT", "links": [ 86 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "EmptySD3LatentImage" }, "widgets_values": [ 768, 1024, 1 ], "color": "#323", "bgcolor": "#535" }, { "id": 16, "type": "KSamplerSelect", "pos": { "0": 384, "1": 313 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [ { "name": "SAMPLER", "type": "SAMPLER", "links": [ 85 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "KSamplerSelect" }, "widgets_values": [ "euler" ] }, { "id": 17, "type": "BasicScheduler", "pos": { "0": 392, "1": 424 }, "size": { "0": 315, "1": 106 }, "flags": { "collapsed": false }, "order": 10, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 129, "slot_index": 0 } ], "outputs": [ { "name": "SIGMAS", "type": "SIGMAS", "links": [ 93 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicScheduler" }, "widgets_values": [ "simple", 10, 1 ] }, { "id": 54, "type": "LoadImage", "pos": { "0": 729, "1": -490 }, "size": { "0": 315, "1": 314 }, "flags": {}, "order": 6, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 126 ], "slot_index": 0, "shape": 3 }, { "name": "MASK", "type": "MASK", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "LoadImage" }, "widgets_values": [ "einstein.jpg", "image" ] }, { "id": 53, "type": "PulidFluxInsightFaceLoader", "pos": { "0": 822, "1": -80 }, "size": { "0": 365.4000244140625, "1": 58 }, "flags": {}, "order": 7, "mode": 0, "inputs": [], "outputs": [ { "name": "FACEANALYSIS", "type": "FACEANALYSIS", "links": [ 124 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxInsightFaceLoader" }, "widgets_values": [ "CPU" ] }, { "id": 51, "type": "PulidFluxEvaClipLoader", "pos": { "0": 845, "1": 52 }, "size": { "0": 327.5999755859375, "1": 26 }, "flags": {}, "order": 8, "mode": 0, "inputs": [], "outputs": [ { "name": "EVA_CLIP", "type": "EVA_CLIP", "links": [ 123 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxEvaClipLoader" } }, { "id": 45, "type": "PulidFluxModelLoader", "pos": { "0": 846, "1": 137 }, "size": { "0": 315, "1": 58 }, "flags": {}, "order": 9, "mode": 0, "inputs": [], "outputs": [ { "name": "PULIDFLUX", "type": "PULIDFLUX", "links": [ 125 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "PulidFluxModelLoader" }, "widgets_values": [ "pulid_flux_v0.9.0.safetensors" ] }, { "id": 62, "type": "ApplyPulidFlux", "pos": { "0": 842, "1": 258 }, "size": { "0": 315, "1": 206 }, "flags": {}, "order": 12, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 127 }, { "name": "pulid_flux", "type": "PULIDFLUX", "link": 125 }, { "name": "eva_clip", "type": "EVA_CLIP", "link": 123 }, { "name": "face_analysis", "type": "FACEANALYSIS", "link": 124 }, { "name": "image", "type": "IMAGE", "link": 126 }, { "name": "attn_mask", "type": "MASK", "link": null } ], "outputs": [ { "name": "MODEL", "type": "MODEL", "links": [ 122 ], "shape": 3, "slot_index": 0 } ], "properties": { "Node name for S&R": "ApplyPulidFlux" }, 
"widgets_values": [ 1, 0, 1 ] }, { "id": 47, "type": "BasicGuider", "pos": { "0": 1217, "1": 401 }, "size": { "0": 241.79998779296875, "1": 46 }, "flags": {}, "order": 14, "mode": 0, "inputs": [ { "name": "model", "type": "MODEL", "link": 122 }, { "name": "conditioning", "type": "CONDITIONING", "link": 107 } ], "outputs": [ { "name": "GUIDER", "type": "GUIDER", "links": [ 83 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "BasicGuider" } }, { "id": 48, "type": "SamplerCustomAdvanced", "pos": { "0": 1205, "1": -39 }, "size": { "0": 355.20001220703125, "1": 326 }, "flags": {}, "order": 15, "mode": 0, "inputs": [ { "name": "noise", "type": "NOISE", "link": 84 }, { "name": "guider", "type": "GUIDER", "link": 83 }, { "name": "sampler", "type": "SAMPLER", "link": 85 }, { "name": "sigmas", "type": "SIGMAS", "link": 93 }, { "name": "latent_image", "type": "LATENT", "link": 86 } ], "outputs": [ { "name": "output", "type": "LATENT", "links": [ 87 ], "slot_index": 0, "shape": 3 }, { "name": "denoised_output", "type": "LATENT", "links": null, "shape": 3 } ], "properties": { "Node name for S&R": "SamplerCustomAdvanced" } }, { "id": 49, "type": "VAEDecode", "pos": { "0": 1263, "1": -137 }, "size": { "0": 210, "1": 46 }, "flags": {}, "order": 16, "mode": 0, "inputs": [ { "name": "samples", "type": "LATENT", "link": 87 }, { "name": "vae", "type": "VAE", "link": 88 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 89 ], "slot_index": 0, "shape": 3 } ], "properties": { "Node name for S&R": "VAEDecode" } }, { "id": 50, "type": "PreviewImage", "pos": { "0": 1587, "1": -169 }, "size": { "0": 841.524169921875, "1": 698.3060302734375 }, "flags": {}, "order": 17, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", "link": 89 } ], "outputs": [], "properties": { "Node name for S&R": "PreviewImage" } } ], "links": [ [ 41, 6, 0, 26, 0, "CONDITIONING" ], [ 83, 47, 0, 48, 1, "GUIDER" ], [ 84, 25, 0, 48, 0, "NOISE" ], [ 85, 16, 0, 48, 2, "SAMPLER" ], [ 86, 27, 0, 48, 4, "LATENT" ], [ 87, 48, 0, 49, 0, "LATENT" ], [ 88, 10, 0, 49, 1, "VAE" ], [ 89, 49, 0, 50, 0, "IMAGE" ], [ 93, 17, 0, 48, 3, "SIGMAS" ], [ 107, 26, 0, 47, 1, "CONDITIONING" ], [ 122, 62, 0, 47, 0, "MODEL" ], [ 123, 51, 0, 62, 2, "EVA_CLIP" ], [ 124, 53, 0, 62, 3, "FACEANALYSIS" ], [ 125, 45, 0, 62, 1, "PULIDFLUX" ], [ 126, 54, 0, 62, 4, "IMAGE" ], [ 127, 31, 0, 62, 0, "MODEL" ], [ 128, 41, 0, 6, 0, "CLIP" ], [ 129, 31, 0, 17, 0, "MODEL" ] ], "groups": [], "config": {}, "extra": { "ds": { "scale": 0.7513148009015777, "offset": [ 124.42912136813258, 743.5079061935592 ] } }, "version": 0.4 } ================================================ FILE: online_train1.py ================================================ # supervised by a global average embedding, which is a biased estimation of the true embedding # use projection to enable a complex decoding # makes no big difference than mean so far, the decoding may not work 🤦‍ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch from tqdm import tqdm import random class Transform(nn.Module): def __init__(self, n=2, token_size=32, input_dim=2048): super().__init__() self.n=n self.dim= input_dim*token_size self.token_size=token_size self.input_dim=input_dim self.weight = nn.Parameter(torch.ones(self.n,1),requires_grad=True) self.projections = nn.ModuleList([nn.Sequential( nn.Linear(self.dim, 512), nn.ReLU(), nn.Linear(512, self.dim) ) for _ in range(self.n)]) def encode(self, x): x = x.view(-1, self.dim) x = self.weight*x return 
x def decode(self, x): out=[] for i in range(self.n): t = self.projections[i](x[i]) out.append(t) x = torch.stack(out, dim=0) x=x.view(self.n,self.token_size,self.input_dim) x=torch.mean(x,dim=0) return x def forward(self, x): x = self.encode(x) x = self.decode(x) return x def online_train(cond, device="cuda:1",step=1000): old_device=cond.device dtype=cond.dtype cond = cond.clone().to(device,torch.float32) cond.requires_grad=False torch.set_grad_enabled(True) print("online training, initializing model...") n=cond.shape[0] model=Transform(n=n) optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001) criterion = nn.MSELoss() model.to(device) model.train() y=torch.mean(cond,dim=0) random.seed(42) bar=tqdm(range(step)) for s in bar: optimizer.zero_grad() attack_weight=[random.uniform(0.5,1.5) for _ in range(n)] attack_weight=torch.tensor(attack_weight)[:,None,None].to(device) x=attack_weight*cond output = model(x) loss = criterion(output, y) loss.backward() optimizer.step() bar.set_postfix(loss=loss.item()) weight=model.weight cond=weight[:,:,None]*cond print(weight) print("online training, ending...") del model del optimizer cond=torch.mean(cond,dim=0).unsqueeze(0) return cond.to(old_device,dtype=dtype) ================================================ FILE: online_train2.py ================================================ # self-supervised learning, one of the embedding acts as the target, the other as the support # works nicely import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch from tqdm import tqdm import random class Transform(nn.Module): def __init__(self, n=2, token_size=32, input_dim=2048): super().__init__() self.n=n self.token_size=token_size self.weight = nn.Parameter(torch.ones(self.n,self.token_size),requires_grad=True) def encode(self, x): x = torch.einsum('bij,bi->ij', x, self.weight) return x def forward(self, x): x = self.encode(x) return x def criterion(output, target, token_sample_rate=0.25): t=target-output t=torch.norm(t,dim=1) s=random.sample(range(t.shape[0]),int(token_sample_rate*t.shape[0])) return torch.mean(t[s]) def online_train(cond, device="cuda:1",step=1000): old_device=cond.device dtype=cond.dtype cond = cond.clone().to(device,torch.float32) # cond.requires_grad=False # torch.set_grad_enabled(True) y=cond[0,:,:] cond=cond[1:,:,:] print("online training, initializing model...") n=cond.shape[0] model=Transform(n=n) optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001) model.to(device) model.train() random.seed(42) bar=tqdm(range(step)) for s in bar: optimizer.zero_grad() x=cond output = model(x) loss = criterion(output, y) loss.backward() optimizer.step() bar.set_postfix(loss=loss.item()) weight=model.weight print(weight) cond=weight[:,:,None]*cond+y[None,:,:]*(1.0/n) print("online training, ending...") del model del optimizer cond=torch.mean(cond,dim=0).unsqueeze(0) return cond.to(old_device,dtype=dtype) ================================================ FILE: pulidflux.py ================================================ import torch from torch import nn, Tensor from torchvision import transforms from torchvision.transforms import functional import os import logging import folder_paths import comfy.utils from comfy.ldm.flux.layers import timestep_embedding import comfy.model_management from insightface.app import FaceAnalysis from facexlib.parsing import init_parsing_model from facexlib.utils.face_restoration_helper import FaceRestoreHelper import torch.nn.functional as F from 
.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .encoders_flux import IDFormer, PerceiverAttentionCA INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface") MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid") if "pulid" not in folder_paths.folder_names_and_paths: current_paths = [MODELS_DIR] else: current_paths, _ = folder_paths.folder_names_and_paths["pulid"] folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions) from .online_train2 import online_train class PulidFluxModel(nn.Module): def __init__(self): super().__init__() self.double_interval = 2 self.single_interval = 4 # Init encoder self.pulid_encoder = IDFormer() # Init attention num_ca = 19 // self.double_interval + 38 // self.single_interval if 19 % self.double_interval != 0: num_ca += 1 if 38 % self.single_interval != 0: num_ca += 1 self.pulid_ca = nn.ModuleList([ PerceiverAttentionCA() for _ in range(num_ca) ]) def from_pretrained(self, path: str): state_dict = comfy.utils.load_torch_file(path, safe_load=True) state_dict_dict = {} for k, v in state_dict.items(): module = k.split('.')[0] state_dict_dict.setdefault(module, {}) new_k = k[len(module) + 1:] state_dict_dict[module][new_k] = v for module in state_dict_dict: getattr(self, module).load_state_dict(state_dict_dict[module], strict=True) del state_dict del state_dict_dict def get_embeds(self, face_embed, clip_embeds): return self.pulid_encoder(face_embed, clip_embeds) def forward_orig( self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, y: Tensor, guidance: Tensor = None, control=None, transformer_options={}, attn_mask: Tensor = None, **kwargs # so it won't break if we add more stuff in the future ) -> Tensor: device = comfy.model_management.get_torch_device() patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) if self.params.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.vector_in(y) txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) ca_idx = 0 blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attn_mask": attn_mask}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] if add is not None: img += add # PuLID attention if self.pulid_data: if i % self.pulid_double_interval == 0: # Will calculate influence of all pulid nodes at once for _, node_data in self.pulid_data.items(): condition_start = node_data['sigma_start'] >= timesteps condition_end = timesteps >= node_data['sigma_end'] condition = torch.logical_and( condition_start, condition_end).all() if condition: img = img + 
node_data['weight'] * self.pulid_ca[ca_idx].to(device)(node_data['embedding'], img) ca_idx += 1 img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attn_mask": attn_mask}, {"original_block": block_wrap}) img = out["img"] else: img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: img[:, txt.shape[1] :, ...] += add # PuLID attention if self.pulid_data: real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] if i % self.pulid_single_interval == 0: # Will calculate influence of all nodes at once for _, node_data in self.pulid_data.items(): condition_start = node_data['sigma_start'] >= timesteps condition_end = timesteps >= node_data['sigma_end'] # Combine conditions and reduce to a single boolean condition = torch.logical_and(condition_start, condition_end).all() if condition: real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx].to(device)(node_data['embedding'], real_img) ca_idx += 1 img = torch.cat((txt, real_img), 1) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img def tensor_to_image(tensor): image = tensor.mul(255).clamp(0, 255).byte().cpu() image = image[..., [2, 1, 0]].numpy() return image def image_to_tensor(image): tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1) tensor = tensor[..., [2, 1, 0]] return tensor def resize_with_pad(img, target_size): # image: 1, h, w, 3 img = img.permute(0, 3, 1, 2) H, W = target_size h, w = img.shape[2], img.shape[3] scale_h = H / h scale_w = W / w scale = min(scale_h, scale_w) new_h = int(min(h * scale,H)) new_w = int(min(w * scale,W)) new_size = (new_h, new_w) img = F.interpolate(img, size=new_size, mode='bicubic', align_corners=False) pad_top = (H - new_h) // 2 pad_bottom = (H - new_h) - pad_top pad_left = (W - new_w) // 2 pad_right = (W - new_w) - pad_left img = F.pad(img, pad=(pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) return img.permute(0, 2, 3, 1) def to_gray(img): x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] x = x.repeat(1, 3, 1, 1) return x """ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Nodes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ class PulidFluxModelLoader: @classmethod def INPUT_TYPES(s): return {"required": {"pulid_file": (folder_paths.get_filename_list("pulid"), )}} RETURN_TYPES = ("PULIDFLUX",) FUNCTION = "load_model" CATEGORY = "pulid" def load_model(self, pulid_file): model_path = folder_paths.get_full_path("pulid", pulid_file) # Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node model = PulidFluxModel() logging.info("Loading PuLID-Flux model.") model.from_pretrained(path=model_path) return (model,) class PulidFluxInsightFaceLoader: @classmethod def INPUT_TYPES(s): return { "required": { "provider": (["CPU", "CUDA", "ROCM"], ), }, } RETURN_TYPES = ("FACEANALYSIS",) FUNCTION = "load_insightface" CATEGORY = "pulid" def load_insightface(self, provider): model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, 
providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l model.prepare(ctx_id=0, det_size=(640, 640)) return (model,) class PulidFluxEvaClipLoader: @classmethod def INPUT_TYPES(s): return { "required": {}, } RETURN_TYPES = ("EVA_CLIP",) FUNCTION = "load_eva_clip" CATEGORY = "pulid" def load_eva_clip(self): from .eva_clip.factory import create_model_and_transforms model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True) model = model.visual eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN) eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD) if not isinstance(eva_transform_mean, (list, tuple)): model["image_mean"] = (eva_transform_mean,) * 3 if not isinstance(eva_transform_std, (list, tuple)): model["image_std"] = (eva_transform_std,) * 3 return (model,) class ApplyPulidFlux: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL", ), "pulid_flux": ("PULIDFLUX", ), "eva_clip": ("EVA_CLIP", ), "face_analysis": ("FACEANALYSIS", ), "image": ("IMAGE", ), "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }), "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }), "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }), "fusion": (["mean","concat","max","norm_id","max_token","auto_weight","train_weight"],), "fusion_weight_max": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 20.0, "step": 0.1 }), "fusion_weight_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.1 }), "train_step": ("INT", {"default": 1000, "min": 0, "max": 20000, "step": 1 }), "use_gray": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}), }, "optional": { "attn_mask": ("MASK", ), "prior_image": ("IMAGE",), # for train weight, as the target }, "hidden": { "unique_id": "UNIQUE_ID" }, } RETURN_TYPES = ("MODEL",) FUNCTION = "apply_pulid_flux" CATEGORY = "pulid" def __init__(self): self.pulid_data_dict = None def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, prior_image=None,fusion="mean", fusion_weight_max=1.0, fusion_weight_min=0.0, train_step=1000, use_gray=True, attn_mask=None, unique_id=None): device = comfy.model_management.get_torch_device() # Why should I care what args say, when the unet model has a different dtype?! # Am I missing something?! #dtype = comfy.model_management.unet_dtype() dtype = model.model.diffusion_model.dtype # For 8bit use bfloat16 (because ufunc_add_CUDA is not implemented) if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: dtype = torch.bfloat16 eva_clip.to(device, dtype=dtype) pulid_flux.to(device, dtype=dtype) # TODO: Add masking support! 
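# The optional attn_mask is coerced to a 3-D (batch, H, W) tensor and moved to the
# sampling device/dtype; per the TODO above it is prepared here but not yet consumed
# by the patched forward_orig. If a prior_image is provided (the target embedding for
# the "train_weight" fusion mode), it is letterbox-resized to the reference image size
# via resize_with_pad and prepended to the image batch, so it is encoded exactly like
# the other reference faces below.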
if attn_mask is not None: if attn_mask.dim() > 3: attn_mask = attn_mask.squeeze(-1) elif attn_mask.dim() < 3: attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.to(device, dtype=dtype) if prior_image is not None: prior_image = resize_with_pad(prior_image.to(image.device, dtype=image.dtype), target_size=(image.shape[1], image.shape[2])) image=torch.cat((prior_image,image),dim=0) image = tensor_to_image(image) face_helper = FaceRestoreHelper( upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device=device, ) face_helper.face_parse = None face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device) bg_label = [0, 16, 18, 7, 8, 9, 14, 15] cond = [] # Analyse multiple images at multiple sizes and combine largest area embeddings for i in range(image.shape[0]): # get insightface embeddings iface_embeds = None for size in [(size, size) for size in range(640, 256, -64)]: face_analysis.det_model.input_size = size face_info = face_analysis.get(image[i]) if face_info: # Only use the maximum face # Removed the reverse=True from original code because we need the largest area not the smallest one! # Sorts the list in ascending order (smallest to largest), # then selects the last element, which is the largest face face_info = sorted(face_info, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1] iface_embeds = torch.from_numpy(face_info.embedding).unsqueeze(0).to(device, dtype=dtype) break else: # No face detected, skip this image logging.warning(f'Warning: No face detected in image {str(i)}') continue # get eva_clip embeddings face_helper.clean_all() face_helper.read_image(image[i]) face_helper.get_face_landmarks_5(only_center_face=True) face_helper.align_warp_face() if len(face_helper.cropped_faces) == 0: # No face detected, skip this image continue # Get aligned face image align_face = face_helper.cropped_faces[0] # Convert bgr face image to tensor align_face = image_to_tensor(align_face).unsqueeze(0).permute(0, 3, 1, 2).to(device) parsing_out = face_helper.face_parse(functional.normalize(align_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] parsing_out = parsing_out.argmax(dim=1, keepdim=True) bg = sum(parsing_out == i for i in bg_label).bool() white_image = torch.ones_like(align_face) # Only keep the face features if use_gray: _align_face = to_gray(align_face) else: _align_face = align_face face_features_image = torch.where(bg, white_image, _align_face) # Transform img before sending to eva_clip # Apparently MPS only supports NEAREST interpolation? 
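# EVA-CLIP preprocessing: resize the background-whitened (optionally grayscale) face
# crop to eva_clip.image_size (bicubic on CUDA, nearest otherwise as an MPS workaround)
# and normalize it with the model's mean/std before encoding. The L2-normalized global
# CLIP embedding is concatenated with the InsightFace embedding and, together with the
# intermediate hidden states, passed through the PuLID IDFormer (pulid_flux.get_embeds)
# to yield one 32x2048 identity token sequence per reference image, collected in `cond`
# for the fusion step below.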
face_features_image = functional.resize(face_features_image, eva_clip.image_size, transforms.InterpolationMode.BICUBIC if 'cuda' in device.type else transforms.InterpolationMode.NEAREST).to(device, dtype=dtype) face_features_image = functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std) # eva_clip id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False) id_cond_vit = id_cond_vit.to(device, dtype=dtype) for idx in range(len(id_vit_hidden)): id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype) id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True)) # Combine embeddings id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1) # Pulid_encoder cond.append(pulid_flux.get_embeds(id_cond, id_vit_hidden)) if not cond: # No faces detected, return the original model logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.") return (model,) # fusion embeddings if fusion == "mean": cond = torch.cat(cond).to(device, dtype=dtype) # N,32,2048 if cond.shape[0] > 1: cond = torch.mean(cond, dim=0, keepdim=True) elif fusion == "concat": cond = torch.cat(cond, dim=1).to(device, dtype=dtype) elif fusion == "max": cond = torch.cat(cond).to(device, dtype=dtype) if cond.shape[0] > 1: cond = torch.max(cond, dim=0, keepdim=True)[0] elif fusion == "norm_id": cond = torch.cat(cond).to(device, dtype=dtype) if cond.shape[0] > 1: norm=torch.norm(cond,dim=(1,2)) norm=norm/torch.sum(norm) cond=torch.einsum("wij,w->ij",cond,norm).unsqueeze(0) elif fusion == "max_token": cond = torch.cat(cond).to(device, dtype=dtype) if cond.shape[0] > 1: norm=torch.norm(cond,dim=2) _,idx=torch.max(norm,dim=0) cond=torch.stack([cond[j,i] for i,j in enumerate(idx)]).unsqueeze(0) elif fusion == "auto_weight": # 🤔 cond = torch.cat(cond).to(device, dtype=dtype) if cond.shape[0] > 1: norm=torch.norm(cond,dim=2) order=torch.argsort(norm,descending=False,dim=0) regular_weight=torch.linspace(fusion_weight_min,fusion_weight_max,norm.shape[0]).to(device, dtype=dtype) _cond=[] for i in range(cond.shape[1]): o=order[:,i] _cond.append(torch.einsum('ij,i->j',cond[:,i,:],regular_weight[o])) cond=torch.stack(_cond,dim=0).unsqueeze(0) elif fusion == "train_weight": cond = torch.cat(cond).to(device, dtype=dtype) if cond.shape[0] > 1: if train_step > 0: with torch.inference_mode(False): cond = online_train(cond, device=cond.device, step=train_step) else: cond = torch.mean(cond, dim=0, keepdim=True) sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at) sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at) # Patch the Flux model (original diffusion_model) # Nah, I don't care for the official ModelPatcher because it's undocumented! # I want the end result now, and I don’t mind if I break other custom nodes in the process. 😄 flux_model = model.model.diffusion_model # Let's see if we already patched the underlying flux model, if not apply patch if not hasattr(flux_model, "pulid_ca"): # Add perceiver attention, variables and current node data (weight, embedding, sigma_start, sigma_end) # The pulid_data is stored in Dict by unique node index, # so we can chain multiple ApplyPulidFlux nodes! 
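# First-time patch: expose the PuLID cross-attention layers and their double/single
# block intervals on the diffusion model and bind the forward_orig defined above in
# place of the original one. Each ApplyPulidFlux node then registers its own
# weight/embedding/sigma window in flux_model.pulid_data under its unique_id, which is
# what allows several PuLID nodes (carrying different identities) to be chained on one
# model instance.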
flux_model.pulid_ca = pulid_flux.pulid_ca flux_model.pulid_double_interval = pulid_flux.double_interval flux_model.pulid_single_interval = pulid_flux.single_interval flux_model.pulid_data = {} # Replace model forward_orig with our own new_method = forward_orig.__get__(flux_model, flux_model.__class__) setattr(flux_model, 'forward_orig', new_method) # Patch is already in place, add data (weight, embedding, sigma_start, sigma_end) under unique node index flux_model.pulid_data[unique_id] = { 'weight': weight, 'embedding': cond, 'sigma_start': sigma_start, 'sigma_end': sigma_end, } # Keep a reference for destructor (if node is deleted the data will be deleted as well) self.pulid_data_dict = {'data': flux_model.pulid_data, 'unique_id': unique_id} return (model,) def __del__(self): # Destroy the data for this node if self.pulid_data_dict: del self.pulid_data_dict['data'][self.pulid_data_dict['unique_id']] del self.pulid_data_dict NODE_CLASS_MAPPINGS = { "PulidFluxModelLoader": PulidFluxModelLoader, "PulidFluxInsightFaceLoader": PulidFluxInsightFaceLoader, "PulidFluxEvaClipLoader": PulidFluxEvaClipLoader, "ApplyPulidFlux": ApplyPulidFlux, } NODE_DISPLAY_NAME_MAPPINGS = { "PulidFluxModelLoader": "Load PuLID Flux Model", "PulidFluxInsightFaceLoader": "Load InsightFace (PuLID Flux)", "PulidFluxEvaClipLoader": "Load Eva Clip (PuLID Flux)", "ApplyPulidFlux": "Apply PuLID Flux", } ================================================ FILE: requirements.txt ================================================ facexlib insightface onnxruntime onnxruntime-gpu ftfy timm
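As a reference for the fusion logic in ApplyPulidFlux.apply_pulid_flux (pulidflux.py), the following is a minimal, self-contained sketch using dummy tensors only; it mirrors the "mean", "max", "norm_id" and "max_token" branches, not the full node:

import torch

# three per-image identity embeddings, shape (N, 32, 2048) as produced by the IDFormer
cond = torch.randn(3, 32, 2048)

# "mean": plain average over the reference images
mean_fused = cond.mean(dim=0, keepdim=True)

# "max": element-wise maximum over the reference images
max_fused = cond.max(dim=0, keepdim=True)[0]

# "norm_id": weight each embedding by its relative overall norm
norm = torch.norm(cond, dim=(1, 2))
norm = norm / norm.sum()
norm_id_fused = torch.einsum("wij,w->ij", cond, norm).unsqueeze(0)

# "max_token": for every token position keep the variant with the largest norm
token_norm = torch.norm(cond, dim=2)   # (N, 32)
_, idx = token_norm.max(dim=0)         # (32,)
max_token_fused = torch.stack([cond[j, i] for i, j in enumerate(idx)]).unsqueeze(0)

Each result has shape (1, 32, 2048); the "concat" mode instead concatenates the token sequences along dim=1 rather than reducing them, and "auto_weight"/"train_weight" reweight tokens before averaging.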