Repository: SeaArtLab/ComfyUI-Long-CLIP
Branch: main
Commit: 889e552b1280
Files: 10
Total size: 90.8 KB
Directory structure:
gitextract_dj6dn1g3/
├── README-CN.md
├── README.md
├── __init__.py
├── long_clip.py
├── long_clip_model/
│ ├── longclip.py
│ ├── model_longclip.py
│ └── simple_tokenizer.py
└── workflow/
├── flux-long.json
├── sd1-5-long.json
└── sdxl-long.json
================================================
FILE CONTENTS
================================================
================================================
FILE: README-CN.md
================================================
# ComfyUI-Long-CLIP
本项目是long-clip的comfyui实现,目前支持clip-l的替换,对于SD1.5可以使用SeaArtLongClip模块加载后替换模型中原本的clip,token的长度由77扩大至248,经过测试我们发现long-clip对成图质量有提升作用,对于SDXL模型由于clip-g的clip-long模型没有出现,所以我们的处理流程如下,对于较小的token按照原本max_len的整数倍扩大,由于最后多出来的为pad_token,所以对多余部分进行了裁剪,由于SDXL中clip-g的特征占比更多,您可能看见的更多是画面出现更多细节,最后喜欢本项目请帮我们点赞
## Start
```
git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI/custom_nodes
git clone git@github.com:SeaArtLab/ComfyUI-Long-CLIP.git
```
下载[LongCLIP-L](https://huggingface.co/BeichenZhang/LongCLIP-L)到 models/checkpoints,同时感谢[Long-CLIP](https://github.com/beichenzbc/Long-CLIP/tree/main)开放权重,后续LongCLIP-G权重开放后,我们会同步支持!
## Workflow
我们特意准备了SD1.5和SDXL的使用例子,为了简化演示,我们例子简单,您无需安装其他插件


================================================
FILE: README.md
================================================
# ComfyUI-Long-CLIP (Flux Support Now)
This project is a ComfyUI implementation of Long-CLIP. It currently supports replacing CLIP-L. For SD1.5, load the model with the SeaArtLongClip node and use it in place of the checkpoint's original CLIP; this expands the maximum token length from 77 to 248. In our tests, Long-CLIP improved the quality of the generated images. For SDXL, a Long-CLIP version of CLIP-G has not been released yet, so we proceed as follows: the shorter token sequence is extended to an integer multiple of its original max_len, and because the tokens added at the end are pad tokens, the excess part is trimmed. Since CLIP-G features make up the larger share of the SDXL conditioning, the difference you are most likely to notice is additional detail in the image. Finally, if you like our project, please give us a thumbs up.
Thanks to [zer0int](https://github.com/zer0int)'s work, now long-clip provides support for flux.
## Start
```
git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI/custom_nodes
git clone git@github.com:SeaArtLab/ComfyUI-Long-CLIP.git
```
Download [LongCLIP-L](https://huggingface.co/BeichenZhang/LongCLIP-L) to models/checkpoints, and thanks to [Long-CLIP](https://github.com/beichenzbc/Long-CLIP/tree/main) for making the weights available. Once the LongCLIP-G weights are released, we will also support them!
## Workflow
We have specifically prepared examples for SD1.5 and SDXL for your use. To simplify the demonstration, our examples are straightforward, and you do not need to install any additional plugins. This plugin also supports operations such as clip-skip.



================================================
FILE: __init__.py
================================================
from . import long_clip as long_clip
# Maps each node identifier to the class in ``long_clip`` that implements it.
# NOTE(review): presumably read by ComfyUI's custom-node loader - confirm.
NODE_CLASS_MAPPINGS = {
"SeaArtLongClip": long_clip.SeaArtLongClip,
"SeaArtLongXLClipMerge": long_clip.SeaArtLongXLClipMerge,
"LongCLIPTextEncodeFlux": long_clip.LongCLIPTextEncodeFlux,
}
# Maps each node identifier to its display label (currently identical to the
# identifier for every node).
NODE_DISPLAY_NAME_MAPPINGS = {
"SeaArtLongClip": "SeaArtLongClip",
"SeaArtLongXLClipMerge": "SeaArtLongXLClipMerge",
"LongCLIPTextEncodeFlux": "LongCLIPTextEncodeFlux",
}
================================================
FILE: long_clip.py
================================================
from .long_clip_model import longclip
import os
import torch
from comfy.sd import CLIP
import folder_paths
from comfy.sd1_clip import load_embed,ClipTokenWeightEncoder
from comfy.model_management import get_torch_device
from comfy import model_management
import comfy
class SDLongClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Text encoder that wraps a Long-CLIP model loaded through ``longclip.load``.

    The class mirrors the interface of ComfyUI's SD1 CLIP wrapper (layer
    selection, clip options, textual-inversion embeddings) while delegating
    the actual encoding to the Long-CLIP transformer.
    """
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, dtype=None,
                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True, **kwargs):  # clip-vit-base-patch32
        """Load the Long-CLIP checkpoint and configure which layer is returned.

        Args:
            version: Passed to ``longclip.load`` as the checkpoint to load.
            device: Device handed to ``longclip.load``.
            max_length: Stored as ``self.max_length``.
            freeze: When true, put the model in eval mode and disable gradients.
            layer: One of ``LAYERS``; ``"hidden"`` requires ``layer_idx``.
            layer_idx: Index of the intermediate layer used when ``layer`` is
                ``"hidden"``.
            dtype: Accepted for interface compatibility; not used here.
            special_tokens: Mapping with ``"start"``, ``"end"`` and ``"pad"``
                token ids. Defaults to the CLIP ids 49406 / 49407 / 49407.
            layer_norm_hidden_state: Forwarded to the transformer as
                ``final_layer_norm_intermediate``.
            enable_attention_masks: Build an attention mask that stops after
                the first end token of every row.
            return_projected_pooled: Prefer the third transformer output as
                the pooled result; when false the fourth output is used if it
                is available.
            **kwargs: Ignored.
        """
        super().__init__()
        assert layer in self.LAYERS
        # A dict literal as a default argument would be shared by every
        # instance; build a fresh one per instance instead.
        if special_tokens is None:
            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}

        self.transformer, _ = longclip.load(version, device=device)
        self.num_layers = self.transformer.transformer_layers
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.clip_layer(layer_idx)
        # Snapshots used by reset_clip_layer() / reset_clip_options().
        self.layer_default = (self.layer, self.layer_idx)
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
        self.dtypes = [param.dtype for param in self.parameters()]

    def freeze(self):
        """Switch the transformer to eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def clip_layer(self, layer_idx):
        """Select the intermediate layer to return; out-of-range indices fall back to ``"last"``."""
        if abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_layer(self):
        """Restore the layer selection captured at construction time."""
        self.layer = self.layer_default[0]
        self.layer_idx = self.layer_default[1]

    def set_clip_options(self, options):
        """Apply the ``"layer"`` and ``"projected_pooled"`` entries of ``options``.

        Missing keys keep the current value. A ``None`` or out-of-range layer
        index selects the last layer.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        """Restore layer selection and pooled-output preference to their construction-time values."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def set_up_textual_embeddings(self, tokens, current_embeds):
        """Replace embedding tensors inside ``tokens`` with temporary token ids.

        Integer tokens are kept (the end token is remapped so that it stays
        the largest id). Every tensor entry whose first dimension matches the
        embedding width is assigned a new id, and an enlarged embedding table
        containing those tensors is installed on the transformer. Entries with
        a mismatching width are dropped with a warning and the row is padded
        back to its original length.

        Returns:
            A list of rows containing only integer token ids.
        """
        out_tokens = []
        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
        embedding_weights = []

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, int):
                    if y == token_dict_size:  # EOS token
                        y = -1
                    tokens_temp += [y]
                else:
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights += [y]
                        tokens_temp += [next_new_token]
                        next_new_token += 1
                    else:
                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
            # Dropped embeddings shorten the row; pad it back to its input length.
            while len(tokens_temp) < len(x):
                tokens_temp += [self.special_tokens["pad"]]
            out_tokens += [tokens_temp]

        n = token_dict_size
        if len(embedding_weights) > 0:
            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
            new_embedding.weight[n] = current_embeds.weight[-1]  # EOS embedding
            self.transformer.set_input_embeddings(new_embedding)

        processed_tokens = []
        for x in out_tokens:
            # The EOS token should always be the largest one
            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))]

        return processed_tokens

    def forward(self, tokens):
        """Encode rows of tokens and return ``(hidden_states, pooled_output)``.

        The transformer's input embeddings may be swapped temporarily to
        accommodate textual-inversion embeddings; the original table is always
        restored, even when encoding raises.

        Returns:
            A tuple of the selected hidden states cast to float and the pooled
            output cast to float (or ``None`` when the transformer does not
            provide one).
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        try:
            tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
            tokens = torch.LongTensor(tokens).to(device)

            attention_mask = None
            if self.enable_attention_masks:
                attention_mask = torch.zeros_like(tokens)
                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
                # Attend up to and including the first end token of each row.
                for x in range(attention_mask.shape[0]):
                    for y in range(attention_mask.shape[1]):
                        attention_mask[x, y] = 1
                        if tokens[x, y] == max_token:
                            break

            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        finally:
            # Never leave the temporary embedding table installed.
            self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]

        pooled_output = None
        if len(outputs) >= 3:
            # NOTE(review): outputs[2] looks like the projected pooled output and
            # outputs[3] the unprojected one - confirm against the transformer.
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        return z.float(), pooled_output

    def encode(self, tokens):
        """Alias for calling the module on ``tokens``."""
        return self(tokens)

    def load_sd(self, sd):
        """Load a state dict, routing text-projection weights to ``self.text_projection``.

        ``"text_projection"`` is copied as-is, ``"text_projection.weight"`` is
        copied transposed; both keys are removed from ``sd`` before the rest is
        loaded non-strictly into the transformer.
        """
        if "text_projection" in sd:
            self.text_projection[:] = sd.pop("text_projection")
        if "text_projection.weight" in sd:
            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
        return self.transformer.load_state_dict(sd, strict=False)
class SDLongTokenizer:
    """Prompt tokenizer for Long-CLIP with weight parsing and embedding support.

    Prompts are split into weighted segments, tokenized word by word with
    ``longclip.only_tokenize`` and packed into batches of ``max_length``
    entries.
    """

    def __init__(self, max_length=248, pad_with_end=True, embedding_directory=None, tokenizer_data=None, embedding_size=768, embedding_key='clip_l', has_start_token=True, pad_to_max_length=True):
        """Configure batch size, padding behaviour and embedding lookup.

        Args:
            max_length: Number of entries per batch, including start/end tokens.
            pad_with_end: Pad with the end token instead of ``0``.
            embedding_directory: Where ``embedding:<name>`` references are
                looked up; ``None`` disables embedding handling.
            tokenizer_data: Stored unchanged on the instance.
            embedding_size: Expected embedding width passed to ``load_embed``.
            embedding_key: Key passed to ``load_embed``.
            has_start_token: Whether tokenizer output begins with a start token.
            pad_to_max_length: Pad every batch up to ``max_length``.
        """
        self.tokenizer = longclip.only_tokenize
        self.max_length = max_length

        # Tokenizing the empty string yields just the special tokens.
        empty = self.tokenizer('')[0]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        # untokenize() needs the id -> text mapping; derive it from the
        # vocabulary of the shared Long-CLIP tokenizer.
        self.inv_vocab = {v: k for k, v in longclip._tokenizer.encoder.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key
        self.tokenizer_data = tokenizer_data

    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            # Retry without surrounding commas, e.g. "embedding:name,".
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                return (embed, embedding_name[len(stripped):])
        return (embed, "")

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        from comfy.sd1_clip import token_weights,escape_important,unescape_important
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # Tokenize word by word; each entry of `tokens` is one word's group.
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # If we find an embedding, deal with the embedding.
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    # If we accidentally have leftover text, continue parsing
                    # using leftover, else move on to the next word.
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # Parse word.
                tokens.append([(t, weight) for t in self.tokenizer(word)[0][self.tokens_start:-1]])

        # Reshape token array to CLIP input size.
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # Determine if we're going to try and keep the tokens in a single batch.
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    if is_large:
                        # Break word in two and add end token.
                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    else:
                        # Add end token and pad.
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # Start new batch.
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t,w,i+1) for t,w in t_group])
                    t_group = []

        # Fill last batch.
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        """Pair every ``(token, weight, ...)`` entry with the vocabulary text of its token."""
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
def pad_tokens(tokens, clip, add_token_num):
    """Append padding-only batches to ``tokens`` until ``add_token_num`` is covered.

    Each appended batch starts with ``clip.end_token`` followed by
    ``clip.max_length - 1`` pad entries (the end token when
    ``clip.pad_with_end`` is true, otherwise ``0``). ``tokens`` is modified
    in place and also returned.
    """
    filler = clip.end_token if clip.pad_with_end else 0
    still_missing = add_token_num
    while still_missing > 0:
        pad_count = clip.max_length - 1
        padding_batch = [(clip.end_token, 1.0, 0)]
        padding_batch += [(filler, 1.0, 0)] * pad_count
        tokens.append(padding_batch)
        # One batch covers the end token plus all of its pads.
        still_missing -= pad_count + 1
    return tokens
def token_num(tokens):
    """Return the total number of entries across all batches in ``tokens``."""
    return sum(len(batch) for batch in tokens)
class SDXLLongClipModel(torch.nn.Module):
    """Pairs a ``clip_l`` encoder with a ``clip_g`` encoder.

    Both attributes start as ``None`` and are assigned by the caller after
    construction.
    """

    def __init__(self):
        super().__init__()
        self.clip_l = None
        self.clip_g = None

    def set_clip_options(self, options):
        """Forward ``options`` to ``clip_l`` and then ``clip_g``."""
        for encoder in (self.clip_l, self.clip_g):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        """Reset the options of ``clip_g`` and then ``clip_l``."""
        for encoder in (self.clip_g, self.clip_l):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode the ``"g"`` and ``"l"`` token sets and concatenate the results.

        Both outputs are cut to the shorter length along dimension 1 before
        being joined on the last dimension (``clip_l`` first). The pooled
        output of ``clip_g`` is returned alongside.
        """
        pairs_g = token_weight_pairs["g"]
        pairs_l = token_weight_pairs["l"]
        g_out, g_pooled = self.clip_g.encode_token_weights(pairs_g)
        l_out, _ = self.clip_l.encode_token_weights(pairs_l)
        shared_len = min(g_out.shape[1], l_out.shape[1])
        joined = torch.cat([l_out[:, :shared_len, :], g_out[:, :shared_len, :]], dim=-1)
        return joined, g_pooled

    def load_sd(self, sd):
        """Load ``sd`` into ``clip_g`` when it contains encoder layer 30, else into ``clip_l``."""
        is_clip_g = "text_model.encoder.layers.30.mlp.fc1.weight" in sd
        target = self.clip_g if is_clip_g else self.clip_l
        return target.load_sd(sd)
class SDXLLongTokenizer:
    """Tokenizes a prompt with both the ``clip_g`` and ``clip_l`` tokenizers.

    Both attributes start as ``None`` and are assigned by the caller.
    """

    def __init__(self):
        self.clip_l = None
        self.clip_g = None

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        """Return ``{"g": ..., "l": ...}`` with both sides padded to an equal token count.

        Whichever side holds fewer entries is extended with padding batches
        via ``pad_tokens``.
        """
        out = {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids),
        }
        g_count = token_num(out["g"])
        l_count = token_num(out["l"])
        # Positive: "l" is shorter; negative: "g" is shorter.
        deficit = g_count - l_count
        if deficit > 0:
            out["l"] = pad_tokens(out["l"], self.clip_l, deficit)
        elif deficit < 0:
            out["g"] = pad_tokens(out["g"], self.clip_g, -deficit)
        return out

    def untokenize(self, token_weight_pair):
        """Delegate to the ``clip_g`` tokenizer."""
        return self.clip_g.untokenize(token_weight_pair)
class LongCLIPFluxModel(torch.nn.Module):
    """Pairs a Long-CLIP ``clip_l`` encoder with a ``t5xxl`` encoder.

    Both attributes start as ``None`` and are assigned by the caller.
    """

    def __init__(self):
        super().__init__()
        self.clip_l = None
        self.t5xxl = None

    def set_clip_options(self, options):
        """Forward ``options`` to ``clip_l`` and then ``t5xxl``."""
        for encoder in (self.clip_l, self.t5xxl):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        """Reset the options of ``clip_l`` and then ``t5xxl``."""
        for encoder in (self.clip_l, self.t5xxl):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode both token sets; return the T5XXL output and the Long-CLIP pooled output."""
        pairs_l = token_weight_pairs["l"]
        pairs_t5 = token_weight_pairs["t5xxl"]
        # Long-CLIP contributes only its pooled vector.
        _, l_pooled = self.clip_l.encode_token_weights(pairs_l)
        # T5XXL contributes only its sequence output.
        t5_out, _ = self.t5xxl.encode_token_weights(pairs_t5)
        return t5_out, l_pooled

    def load_sd(self, sd):
        """Load ``sd`` into ``clip_l`` when it has CLIP-style encoder keys, else into ``t5xxl``."""
        is_clip = "text_model.encoder.layers.1.mlp.fc1.weight" in sd
        target = self.clip_l if is_clip else self.t5xxl
        return target.load_sd(sd)
class LongCLIPFluxTokenizer:
    """Tokenizes a prompt with both the Long-CLIP and the T5XXL tokenizer.

    ``clip_l`` and ``t5xxl`` start as ``None`` and are assigned by the caller.
    """

    def __init__(self):
        self.clip_l = None
        self.t5xxl = None

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Return ``{"l": ..., "t5xxl": ...}`` with each tokenizer's own output.

        The two results are deliberately left at their natural lengths.
        Leaving this here as a reminder: do NOT pad T5XXL to match Long-CLIP.
        """
        return {
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids),  # Long-CLIP tokenization
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids),  # T5XXL tokenization
        }

    def untokenize(self, token_weight_pair):
        """Untokenize using the Long-CLIP tokenizer."""
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        """Return an empty dict; this tokenizer exposes no state."""
        return {}
class SeaArtLongXLClipMerge:
    """Node that swaps the ``clip_l`` half of an existing CLIP for Long-CLIP.

    The ``clip_g`` model and tokenizer of the incoming CLIP are reused.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {
            "clip_name": (folder_paths.get_filename_list("clip"), ),
            "clip": ("CLIP", ),
        }
        return {"required": required_inputs}

    CATEGORY = "SeaArt"
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "do"

    def do(self, clip_name, clip):
        """Return a one-tuple with a clone of ``clip`` whose clip_l model and tokenizer are Long-CLIP."""
        patched_clip = clip.clone()
        checkpoint_path = folder_paths.get_full_path("clip", clip_name)
        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()
        encoder_dtype = model_management.text_encoder_dtype(load_device)

        # Model side: new Long-CLIP encoder plus the clone's existing clip_g.
        merged_model = SDXLLongClipModel()
        merged_model.clip_l = SDLongClipModel(
            version=checkpoint_path,
            layer="hidden",
            layer_idx=-2,
            device=offload_device,
            dtype=encoder_dtype,
            layer_norm_hidden_state=False,
        )
        merged_model.clip_g = patched_clip.cond_stage_model.clip_g
        patched_clip.cond_stage_model = merged_model

        # Tokenizer side: new Long-CLIP tokenizer plus the clone's existing clip_g.
        embedding_directory = folder_paths.get_folder_paths("embeddings")
        merged_tokenizer = SDXLLongTokenizer()
        merged_tokenizer.clip_l = SDLongTokenizer(embedding_directory=embedding_directory)
        merged_tokenizer.clip_g = patched_clip.tokenizer.clip_g
        patched_clip.tokenizer = merged_tokenizer
        return (patched_clip,)
class SeaArtLongClip:
    """Node that builds a CLIP object backed by the Long-CLIP model and tokenizer."""

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {
            "clip_name": (folder_paths.get_filename_list("clip"), ),
        }
        return {"required": required_inputs}

    CATEGORY = "SeaArt"
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "do"

    def do(self, clip_name):
        """Return a one-tuple with a ``CLIP`` built from the selected checkpoint."""
        class EmptyClass:
            pass

        # Plain attribute holder describing which model/tokenizer CLIP should build.
        target = EmptyClass()
        checkpoint_path = folder_paths.get_full_path("clip", clip_name)
        target.params = {"version": checkpoint_path}
        target.clip = SDLongClipModel
        target.tokenizer = SDLongTokenizer
        embeddings_dir = folder_paths.get_folder_paths("embeddings")
        long_clip = CLIP(target, embedding_directory=embeddings_dir)
        return (long_clip,)
class LongCLIPTextEncodeFlux:
    """Node that swaps the ``clip_l`` half of a Flux CLIP for Long-CLIP.

    The ``t5xxl`` model and tokenizer of the incoming CLIP are reused.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {
            "clip_name": (folder_paths.get_filename_list("clip"), ),
            "clip": ("CLIP", ),
        }
        return {"required": required_inputs}

    CATEGORY = "SeaArt"
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "do"

    def do(self, clip_name, clip):
        """Return a one-tuple with a clone of ``clip`` whose clip_l model and tokenizer are Long-CLIP."""
        patched_clip = clip.clone()
        checkpoint_path = folder_paths.get_full_path("clip", clip_name)
        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()
        encoder_dtype = model_management.text_encoder_dtype(load_device)

        # Model side: new Long-CLIP encoder plus the clone's existing T5XXL.
        flux_model = LongCLIPFluxModel()
        flux_model.clip_l = SDLongClipModel(
            version=checkpoint_path,
            layer="hidden",
            layer_idx=-2,
            device=offload_device,
            dtype=encoder_dtype,
            max_length=248,
        )
        flux_model.t5xxl = patched_clip.cond_stage_model.t5xxl
        patched_clip.cond_stage_model = flux_model

        # Tokenizer side: the embedding directory is read from the old clip_l
        # tokenizer before the tokenizer is replaced.
        flux_tokenizer = LongCLIPFluxTokenizer()
        flux_tokenizer.clip_l = SDLongTokenizer(
            embedding_directory=patched_clip.tokenizer.clip_l.embedding_directory,
            max_length=248,
        )
        flux_tokenizer.t5xxl = patched_clip.tokenizer.t5xxl
        patched_clip.tokenizer = flux_tokenizer
        return (patched_clip,)
================================================
FILE: long_clip_model/longclip.py
================================================
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from torch import nn
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from safetensors.torch import load_file
from .model_longclip import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
import packaging
except ImportError:
from pkg_resources import packaging
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
# Warn (without failing) when the installed PyTorch is older than 1.7.1.
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
# Names exported by a star-import of this module.
__all__ = ["load", "tokenize"]
# Module-level tokenizer instance shared by only_tokenize() and tokenize().
_tokenizer = _Tokenizer()
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
    """Build the image preprocessing pipeline for a square input of ``n_px`` pixels.

    The pipeline resizes with bicubic interpolation, center-crops, converts to
    RGB, converts to a tensor and normalizes with fixed per-channel mean and
    standard deviation.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", download_root: str = None):
    """Load a Long-CLIP model from a checkpoint file.

    Parameters
    ----------
    name : str
        Path to a model checkpoint containing the state_dict. Files ending in
        ``.safetensors`` are read with ``safetensors``; anything else is read
        with ``torch.load``.
    device : Union[str, torch.device]
        The device to put the loaded model
    download_root : str
        Unused; kept for interface compatibility.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If the checkpoint contains no weights.
    """
    model_path = name
    if model_path.endswith(".safetensors"):
        state_dict = load_file(model_path, device="cpu")
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")

    # There is no fallback model to take weights from, so an empty checkpoint
    # is an error rather than something to paper over.
    if not state_dict:
        raise RuntimeError(f"Checkpoint {model_path} contains no weights")

    model = build_model(state_dict, load_from_clip=False).to(device)

    if str(device) == "cpu":
        model.float()
    return model, _transform(model.visual.input_resolution)
def load_from_clip(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load from CLIP model for fine-tuning

    The original positional embedding is stretched: the first ``keep_len``
    (20) positions are copied unchanged and the remaining positions are
    linearly interpolated to four times their number, giving
    ``4 * length - 3 * keep_len`` positions in total.

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).
    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If ``name`` is neither a known model nor an existing file, or if a
        downloaded file fails its SHA256 check.
    """
    _MODELS = {
        "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
        "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
        "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
        "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
        "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
        "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
        "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
        "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
    }

    def available_models() -> List[str]:
        """Returns the names of available CLIP models"""
        return list(_MODELS.keys())

    def _sha256_of(path: str) -> str:
        """Return the hex SHA256 digest of the file at ``path``, closing the file afterwards."""
        with open(path, "rb") as handle:
            return hashlib.sha256(handle.read()).hexdigest()

    def _download(url: str, root: str):
        """Download ``url`` into ``root`` (unless a valid copy exists) and return the local path."""
        # A bare ``import urllib`` does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly.
        import urllib.request

        os.makedirs(root, exist_ok=True)
        filename = os.path.basename(url)

        # The second-to-last URL component is the expected checksum.
        expected_sha256 = url.split("/")[-2]
        download_target = os.path.join(root, filename)

        if os.path.exists(download_target) and not os.path.isfile(download_target):
            raise RuntimeError(f"{download_target} exists and is not a regular file")

        if os.path.isfile(download_target):
            if _sha256_of(download_target) == expected_sha256:
                return download_target
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

        with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
            with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))

        if _sha256_of(download_target) != expected_sha256:
            raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

        return download_target

    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
            state_dict = torch.load(opened_file, map_location="cpu")

    # A JIT archive yields no state dict, so take the weights from the JIT model.
    model = build_model(state_dict or model.state_dict(), load_from_clip = True).to(device)

    positional_embedding_pre = model.positional_embedding.type(model.dtype)

    length, dim = positional_embedding_pre.shape
    keep_len = 20
    positional_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype)

    # The first keep_len positions are copied verbatim.
    for i in range(keep_len):
        positional_embedding_new[i] = positional_embedding_pre[i]

    # Every later position is expanded into four linearly interpolated steps
    # towards its successor.
    for i in range(length-1-keep_len):
        positional_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len]
        positional_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4
        positional_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4
        positional_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4

    # The last position has no successor; extrapolate with the final step size.
    positional_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    positional_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    positional_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
    positional_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4

    positional_embedding_res = positional_embedding_new.clone()

    # The stretched table is stored frozen; a trainable copy is stored next to it.
    model.positional_embedding = nn.Parameter(positional_embedding_new, requires_grad=False)
    model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)

    if str(device) == "cpu":
        model.float()
    return model, _transform(model.visual.input_resolution)
def only_tokenize(texts: Union[str, List[str]]) -> List[List[int]]:
    """Tokenize one or more strings without padding or truncation.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    Returns
    -------
    A list with one entry per input string. Each entry is a plain list of
    token ids that starts with the start-of-text token and ends with the
    end-of-text token.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    return [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize the input string(s) into a zero-padded tensor of token ids

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize

    context_length : int
        Number of columns of the returned tensor (defaults to 77*4-60 = 248)

    truncate: bool
        When True, an encoding longer than `context_length` is cut to fit and its last
        id is replaced by the end-of-text id; when False such an input raises

    Returns
    -------
    A tensor of shape [number of input strings, context_length]. Every row holds the
    start-of-text id, the BPE ids and the end-of-text id, followed by zeros.
    The dtype is long for torch < 1.8.0 and int otherwise.

    Raises
    ------
    RuntimeError
        If an encoding exceeds `context_length` and `truncate` is False
    """
    if isinstance(texts, str):
        texts = [texts]

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id, *_tokenizer.encode(text), end_id] for text in texts]

    # Older torch needs long indices for index_select, newer versions accept int.
    needs_long = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    result = torch.zeros(len(encoded), context_length, dtype=torch.long if needs_long else torch.int)

    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            ids = ids[:context_length]
            ids[-1] = end_id
        result[row, :len(ids)] = torch.tensor(ids)

    return result
================================================
FILE: long_clip_model/model_longclip.py
================================================
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> (avgpool) -> 1x1 conv, plus a shortcut.

    The output has `planes * expansion` channels. When `stride > 1` the spatial
    reduction is done by average pooling rather than by strided convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        pooled = stride > 1

        # Every convolution keeps stride 1; when stride > 1 an avgpool follows the second one.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if pooled else nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if pooled or inplanes != planes * Bottleneck.expansion:
            # The shortcut is an avgpool followed by a stride-1 1x1 convolution and batch norm.
            shortcut = OrderedDict()
            shortcut["-1"] = nn.AvgPool2d(stride)
            shortcut["0"] = nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)
            shortcut["1"] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut)

    def forward(self, x: torch.Tensor):
        """Run the three-convolution path and add the (optionally downsampled) input."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu3(out)
class AttentionPool2d(nn.Module):
    """Pool a feature map with one multi-head attention query built from the mean token."""

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        token_count = spacial_dim ** 2 + 1  # one slot per spatial position plus the mean token
        self.positional_embedding = nn.Parameter(torch.randn(token_count, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return the attended mean token, one vector per batch element."""
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (HW+1)NC

        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        # Only the first (mean) token is used as the query; all tokens serve as keys and values.
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1], key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return pooled.squeeze(0)
class ModifiedResNet(nn.Module):
    """
    ResNet variant that differs from torchvision's in three ways:
    - The stem has 3 convolutions instead of 1 and ends in an average pool instead of a max pool.
    - Strided convolutions are anti-aliased: an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        half_width = width // 2

        # Three-convolution stem; only the first convolution is strided.
        self.conv1 = nn.Conv2d(3, half_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(half_width, half_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # Residual stages; _inplanes is updated by _make_layer as each stage is built.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # channel count produced by the last residual stage
        self.attnpool = AttentionPool2d(input_resolution // 32, feature_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of `blocks` bottlenecks; only the first one receives `stride`."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, four residual stages and the attention pool on `x`."""
        x = x.type(self.conv1.weight.dtype)

        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.avgpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.attnpool(x)
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then an MLP, each with a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        hidden = d_model * 4
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        # Stored as a plain attribute (not a buffer); callers may replace it between calls.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over `x` using the currently stored mask, if any."""
        if self.attn_mask is not None:
            # The mask is moved to x's dtype/device and the converted copy is kept.
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
class Transformer(nn.Module):
    """Stack of `layers` residual attention blocks of the given `width`."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor, intermediate_output=None, attn_mask: torch.Tensor = None):
        """Run all blocks over `x`.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to the first block
        intermediate_output : int, optional
            Index of the block whose output is additionally returned; a negative
            value counts from the end (``-1`` is the last block)
        attn_mask : torch.Tensor, optional
            Mask used for this call only. When omitted, every block keeps the mask
            it was constructed with.

        Returns
        -------
        A tuple ``(x, intermediate)`` where ``intermediate`` is a clone of the selected
        block's output, or None when no block was selected.
        """
        if intermediate_output is not None and intermediate_output < 0:
            intermediate_output = self.layers + intermediate_output

        intermediate = None
        for i, block in enumerate(self.resblocks):
            if attn_mask is None:
                # No per-call mask: use whatever mask the block already holds
                # (previously it was overwritten with None, losing the causal mask).
                x = block(x)
            else:
                # Apply the per-call mask, then put the block's own mask back so a
                # later call without a mask is not affected by this one.
                default_mask = block.attn_mask
                block.attn_mask = attn_mask
                try:
                    x = block(x)
                finally:
                    block.attn_mask = default_mask
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
class VisionTransformer(nn.Module):
    """ViT image encoder: patch convolution, class token, transformer, projected class token."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patches: kernel size and stride are both patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        """Encode a batch of images and return the (projected) class-token features."""
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Adding the class embedding to zeros broadcasts it to one token per batch element.
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        # The transformer returns (output, intermediate); only the output is needed here.
        x, _ = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
class CLIP(nn.Module):
    """CLIP model whose text tower uses 248 positions split over two positional embeddings.

    The first 20 positions take `positional_embedding`, the remaining positions take
    `positional_embedding_res` (selected through `mask1` / `mask2`).
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 load_from_clip: bool
                 ):
        super().__init__()

        # NOTE: the `context_length` argument is accepted but not used; the length is fixed at 248.
        self.context_length = 248

        # A tuple/list of layer counts selects the ResNet tower, an int selects the ViT tower.
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )
        self.transformer_layers = transformer_layers
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)

        if not load_from_clip:
            self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
            self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))
        else:
            # Only the 77-row embedding is created here; `positional_embedding_res`
            # is not defined by this constructor in this mode and must be attached
            # by the caller before any text is encoded.
            self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))

        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

        # Row selectors: mask1 keeps positions [0, 20), mask2 keeps positions [20, 248).
        self.mask1 = torch.zeros([248, 1])
        self.mask1[:20, :] = 1
        self.mask2 = torch.zeros([248, 1])
        self.mask2[20:, :] = 1

    def initialize_parameters(self):
        """Initialise embeddings, attention/MLP weights and projections with normal noise."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return the additive causal mask of size context_length x context_length."""
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    @staticmethod
    def _causal_mask(length, dtype, device):
        """Return an additive causal mask of shape (length, length) on the given dtype/device."""
        return torch.empty(length, length, dtype=dtype, device=device).fill_(float("-inf")).triu_(1)

    def _embed_tokens(self, tokens):
        """Embed token ids and add the two masked positional embeddings."""
        x = self.token_embedding(tokens).type(self.dtype)  # [batch_size, n_ctx, d_model]
        device = x.device
        head = (self.positional_embedding.to(device) * self.mask1.to(device)).type(self.dtype)
        tail = (self.positional_embedding_res.to(device) * self.mask2.to(device)).type(self.dtype)
        return x + head + tail

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, embeddings):
        self.token_embedding = embeddings

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Return the projected features taken at the highest token id of each sequence."""
        x = self._embed_tokens(text)

        x = x.permute(1, 0, 2)  # NLD -> LND
        # The transformer returns (output, intermediate); pass the causal mask explicitly.
        x, _ = self.transformer(x, attn_mask=self._causal_mask(x.shape[0], x.dtype, x.device))
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0], device=x.device), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_text_full(self, text):
        """Return the layer-normed hidden states of every position (no pooling, no projection)."""
        x = self._embed_tokens(text)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, _ = self.transformer(x, attn_mask=self._causal_mask(x.shape[0], x.dtype, x.device))
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        return x

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
        """Encode token ids with the text tower.

        Parameters
        ----------
        input_tokens : tensor of token ids, [batch_size, n_ctx]
        attention_mask : optional tensor with 1 for positions to attend to and 0 for
            positions to ignore; combined with the causal mask
        intermediate_output : optional block index whose output is also returned
        final_layer_norm_intermediate : apply `ln_final` to the intermediate output

        Returns
        -------
        (x, i, pooled_output): final hidden states, the selected intermediate hidden
        states (or None), and the final hidden state at the highest token id of each row.
        """
        x = self._embed_tokens(input_tokens)

        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        causal_mask = self._causal_mask(x.shape[1], x.dtype, x.device)
        if mask is not None:
            mask += causal_mask
            # nn.MultiheadAttention only accepts 2-D or (batch * heads, L, L) masks,
            # so repeat the per-sample mask for every head (batch-major order).
            heads = self.transformer.resblocks[0].attn.num_heads
            mask = mask.expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
        else:
            mask = causal_mask

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, i = self.transformer(x, intermediate_output=intermediate_output, attn_mask=mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        if i is not None and final_layer_norm_intermediate:
            i = self.ln_final(i).type(self.dtype)
        if i is not None:
            i = i.permute(1, 0, 2)  # LND -> NLD

        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output
def convert_weights(model: nn.Module):
    """Cast the parameters of convolution, linear, attention and projection tensors to fp16 in place."""
    attention_tensor_names = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )
    projection_names = ("text_projection", "proj")

    def _convert_weights_to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            for tensor_name in attention_tensor_names:
                tensor = getattr(module, tensor_name)
                if tensor is None:
                    continue
                tensor.data = tensor.data.half()

        # Bare parameter attributes (not wrapped in a layer) are looked up by name.
        for projection_name in projection_names:
            if not hasattr(module, projection_name):
                continue
            projection = getattr(module, projection_name)
            if projection is not None:
                projection.data = projection.data.half()

    model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict, load_from_clip: bool):
    """Build a CLIP model whose hyper-parameters are inferred from `state_dict`, then load it.

    The presence of the "visual.proj" key selects the ViT vision tower; otherwise the
    ResNet tower is assumed. The keys "input_resolution", "context_length" and
    "vocab_size" are removed from `state_dict` (the caller's dict is mutated). The
    model is converted with `convert_weights`, the weights are loaded, and the model
    is returned in eval mode.
    """
    if "visual.proj" in state_dict:
        patch_weight = state_dict["visual.conv1.weight"]
        vision_width = patch_weight.shape[0]
        vision_layers = sum(
            1 for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_weight.shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Count the distinct block indices under each visual.layerN prefix.
        vision_layers = tuple(
            len({k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}")})
            for b in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        pool_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((pool_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == pool_positions
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    block_ids = {k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")}
    transformer_layers = len(block_ids)

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip
    )

    # These entries are not model parameters, so drop them before loading.
    for key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
================================================
FILE: long_clip_model/simple_tokenizer.py
================================================
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
    """Return the absolute path of the gzipped BPE vocabulary located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the reversible BPE.

    Returns a dict with one entry for each of the 256 byte values. Bytes in the
    ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character with the same
    code point; every other byte is mapped, in ascending order, to the next
    unused code point starting at 256. This keeps whitespace and control
    characters, which the bpe code cannot handle, out of the mapped strings.
    The dict lists the directly-mapped bytes first, then the remapped ones.
    """
    direct = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    remapped = [b for b in range(2**8) if b not in direct]

    byte_values = direct + remapped
    code_points = direct + [2**8 + offset for offset in range(len(remapped))]
    return {b: chr(c) for b, c in zip(byte_values, code_points)}
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    The word is a sequence of symbols (variable-length strings); an empty word
    is not supported.
    """
    head, tail = word[0], word[1:]
    # Pair every symbol with its successor: (head, tail[0]), (tail[0], tail[1]), ...
    return set(zip((head, *tail), tail))
def basic_clean(text):
    """Run ftfy's text fixer, unescape HTML entities twice, and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and strip both ends."""
    return re.sub(r'\s+', ' ', text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzipped merges file."""

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # Read the merges with a context manager so the file handle is closed.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the first line and keep the fixed number of merge rules.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]

        # Vocabulary order: byte symbols, byte symbols with the end-of-word marker,
        # merged symbols, then the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])

        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # The special tokens are pre-cached so they are never split by bpe().
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to `token` and return its symbols joined by single spaces.

        Results are memoised in `self.cache`.
        """
        if token in self.cache:
            return self.cache[token]
        # Start from single characters, with '</w>' appended to the last one.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the lowest rank; unknown pairs rank as infinity.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the remainder unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case `text`, split it with `self.pat`, and return the list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's utf-8 bytes to their stand-in unicode characters.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; every '</w>' marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
================================================
FILE: workflow/flux-long.json
================================================
{
"last_node_id": 38,
"last_link_id": 118,
"nodes": [
{
"id": 17,
"type": "BasicScheduler",
"pos": [
480,
1008
],
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 55,
"slot_index": 0
}
],
"outputs": [
{
"name": "SIGMAS",
"type": "SIGMAS",
"links": [
20
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "BasicScheduler"
},
"widgets_values": [
"simple",
20,
1
]
},
{
"id": 16,
"type": "KSamplerSelect",
"pos": [
480,
912
],
"size": {
"0": 315,
"1": 58
},
"flags": {},
"order": 0,
"mode": 0,
"outputs": [
{
"name": "SAMPLER",
"type": "SAMPLER",
"links": [
19
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "KSamplerSelect"
},
"widgets_values": [
"euler"
]
},
{
"id": 26,
"type": "FluxGuidance",
"pos": [
480,
144
],
"size": {
"0": 317.4000244140625,
"1": 58
},
"flags": {},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 41
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
42
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "FluxGuidance"
},
"widgets_values": [
3.5
],
"color": "#233",
"bgcolor": "#355"
},
{
"id": 22,
"type": "BasicGuider",
"pos": [
576,
48
],
"size": {
"0": 222.3482666015625,
"1": 46
},
"flags": {},
"order": 15,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 54,
"slot_index": 0
},
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 42,
"slot_index": 1
}
],
"outputs": [
{
"name": "GUIDER",
"type": "GUIDER",
"links": [
30
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "BasicGuider"
}
},
{
"id": 13,
"type": "SamplerCustomAdvanced",
"pos": [
864,
192
],
"size": {
"0": 272.3617858886719,
"1": 124.53733825683594
},
"flags": {},
"order": 16,
"mode": 0,
"inputs": [
{
"name": "noise",
"type": "NOISE",
"link": 37,
"slot_index": 0
},
{
"name": "guider",
"type": "GUIDER",
"link": 30,
"slot_index": 1
},
{
"name": "sampler",
"type": "SAMPLER",
"link": 19,
"slot_index": 2
},
{
"name": "sigmas",
"type": "SIGMAS",
"link": 20,
"slot_index": 3
},
{
"name": "latent_image",
"type": "LATENT",
"link": 116,
"slot_index": 4
}
],
"outputs": [
{
"name": "output",
"type": "LATENT",
"links": [
24
],
"shape": 3,
"slot_index": 0
},
{
"name": "denoised_output",
"type": "LATENT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "SamplerCustomAdvanced"
}
},
{
"id": 25,
"type": "RandomNoise",
"pos": [
480,
768
],
"size": {
"0": 315,
"1": 82
},
"flags": {},
"order": 1,
"mode": 0,
"outputs": [
{
"name": "NOISE",
"type": "NOISE",
"links": [
37
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "RandomNoise"
},
"widgets_values": [
219670278747233,
"randomize"
],
"color": "#2a363b",
"bgcolor": "#3f5159"
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
866,
367
],
"size": {
"0": 210,
"1": 46
},
"flags": {},
"order": 17,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 24
},
{
"name": "vae",
"type": "VAE",
"link": 12
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
9
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAEDecode"
}
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
384,
240
],
"size": {
"0": 422.84503173828125,
"1": 164.31304931640625
},
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 118
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
41
],
"slot_index": 0
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere"
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 30,
"type": "ModelSamplingFlux",
"pos": [
480,
1152
],
"size": {
"0": 315,
"1": 130
},
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 56,
"slot_index": 0
},
{
"name": "width",
"type": "INT",
"link": 115,
"widget": {
"name": "width"
},
"slot_index": 1
},
{
"name": "height",
"type": "INT",
"link": 114,
"widget": {
"name": "height"
},
"slot_index": 2
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
54,
55
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "ModelSamplingFlux"
},
"widgets_values": [
1.15,
0.5,
1024,
1024
]
},
{
"id": 27,
"type": "EmptySD3LatentImage",
"pos": [
480,
624
],
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "width",
"type": "INT",
"link": 112,
"widget": {
"name": "width"
}
},
{
"name": "height",
"type": "INT",
"link": 113,
"widget": {
"name": "height"
}
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
116
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "EmptySD3LatentImage"
},
"widgets_values": [
1024,
1024,
1
]
},
{
"id": 34,
"type": "PrimitiveNode",
"pos": [
432,
480
],
"size": {
"0": 210,
"1": 82
},
"flags": {},
"order": 2,
"mode": 0,
"outputs": [
{
"name": "INT",
"type": "INT",
"links": [
112,
115
],
"slot_index": 0,
"widget": {
"name": "width"
}
}
],
"title": "width",
"properties": {
"Run widget replace on values": false
},
"widgets_values": [
1024,
"fixed"
],
"color": "#323",
"bgcolor": "#535"
},
{
"id": 35,
"type": "PrimitiveNode",
"pos": [
672,
480
],
"size": {
"0": 210,
"1": 82
},
"flags": {},
"order": 3,
"mode": 0,
"outputs": [
{
"name": "INT",
"type": "INT",
"links": [
113,
114
],
"widget": {
"name": "height"
},
"slot_index": 0
}
],
"title": "height",
"properties": {
"Run widget replace on values": false
},
"widgets_values": [
1024,
"fixed"
],
"color": "#323",
"bgcolor": "#535"
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1155,
196
],
"size": {
"0": 985.3012084960938,
"1": 1060.3828125
},
"flags": {},
"order": 18,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 9
}
],
"properties": {},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 37,
"type": "Note",
"pos": [
480,
1344
],
"size": {
"0": 314.99755859375,
"1": 117.98363494873047
},
"flags": {},
"order": 4,
"mode": 0,
"properties": {
"text": ""
},
"widgets_values": [
"The reference sampling implementation auto adjusts the shift value based on the resolution, if you don't want this you can just bypass (CTRL-B) this ModelSamplingFlux node.\n"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 10,
"type": "VAELoader",
"pos": [
48,
432
],
"size": {
"0": 311.81634521484375,
"1": 60.429901123046875
},
"flags": {},
"order": 5,
"mode": 0,
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
12
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"ae.safetensors"
]
},
{
"id": 28,
"type": "Note",
"pos": [
48,
576
],
"size": {
"0": 336,
"1": 288
},
"flags": {},
"order": 6,
"mode": 0,
"properties": {
"text": ""
},
"widgets_values": [
"If you get an error in any of the nodes above make sure the files are in the correct directories.\n\nSee the top of the examples page for the links : https://comfyanonymous.github.io/ComfyUI_examples/flux/\n\nflux1-dev.safetensors goes in: ComfyUI/models/unet/\n\nt5xxl_fp16.safetensors and clip_l.safetensors go in: ComfyUI/models/clip/\n\nae.safetensors goes in: ComfyUI/models/vae/\n\n\nTip: You can set the weight_dtype above to one of the fp8 types if you have memory issues."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 11,
"type": "DualCLIPLoader",
"pos": [
48,
275
],
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 7,
"mode": 0,
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
117
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "DualCLIPLoader"
},
"widgets_values": [
"t5xxl_fp16.safetensors",
"clip_l.safetensors",
"flux"
]
},
{
"id": 38,
"type": "LongCLIPTextEncodeFlux",
"pos": [
56,
174
],
"size": [
305.3363800289681,
58
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 117
}
],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
118
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "LongCLIPTextEncodeFlux"
},
"widgets_values": [
"longclip-L.pt"
]
},
{
"id": 12,
"type": "UNETLoader",
"pos": [
61,
46
],
"size": {
"0": 315,
"1": 82
},
"flags": {},
"order": 8,
"mode": 0,
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
56
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "UNETLoader"
},
"widgets_values": [
"flux1-dev.safetensors",
"default"
],
"color": "#223",
"bgcolor": "#335"
}
],
"links": [
[
9,
8,
0,
9,
0,
"IMAGE"
],
[
12,
10,
0,
8,
1,
"VAE"
],
[
19,
16,
0,
13,
2,
"SAMPLER"
],
[
20,
17,
0,
13,
3,
"SIGMAS"
],
[
24,
13,
0,
8,
0,
"LATENT"
],
[
30,
22,
0,
13,
1,
"GUIDER"
],
[
37,
25,
0,
13,
0,
"NOISE"
],
[
41,
6,
0,
26,
0,
"CONDITIONING"
],
[
42,
26,
0,
22,
1,
"CONDITIONING"
],
[
54,
30,
0,
22,
0,
"MODEL"
],
[
55,
30,
0,
17,
0,
"MODEL"
],
[
56,
12,
0,
30,
0,
"MODEL"
],
[
112,
34,
0,
27,
0,
"INT"
],
[
113,
35,
0,
27,
1,
"INT"
],
[
114,
35,
0,
30,
2,
"INT"
],
[
115,
34,
0,
30,
1,
"INT"
],
[
116,
27,
0,
13,
4,
"LATENT"
],
[
117,
11,
0,
38,
0,
"CLIP"
],
[
118,
38,
0,
6,
0,
"CLIP"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1.1,
"offset": [
484.9445814412023,
165.9252922919127
]
},
"groupNodes": {}
},
"version": 0.4
}
================================================
FILE: workflow/sd1-5-long.json
================================================
{
"last_node_id": 10,
"last_link_id": 19,
"nodes": [
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
413,
389
],
"size": {
"0": 425.27801513671875,
"1": 180.6060791015625
},
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 19
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
6
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"text, watermark"
]
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1209,
188
],
"size": {
"0": 210,
"1": 46
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7
},
{
"name": "vae",
"type": "VAE",
"link": 8
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
9
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAEDecode"
}
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1451,
189
],
"size": [
210,
270
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 9
}
],
"properties": {},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 5,
"type": "EmptyLatentImage",
"pos": [
473,
609
],
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 0,
"mode": 0,
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
2
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
1024,
1024,
1
]
},
{
"id": 3,
"type": "KSampler",
"pos": [
863,
186
],
"size": {
"0": 315,
"1": 262
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 1
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 4
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 6
},
{
"name": "latent_image",
"type": "LATENT",
"link": 2
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
7
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
425533035839016,
"fixed",
20,
7,
"euler_ancestral",
"normal",
1
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
415,
186
],
"size": {
"0": 422.84503173828125,
"1": 164.31304931640625
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 18
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
4
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"Studio Ghibli ~ Hayao Miyazaki ~ Spirited Away ~ beautiful ultra detailed illustration of a cute witch of the sitting on a tree stump reading a book, her ornate robe reminiscent of the stars in the night sky. This contrast between the fantastical character and the more bold color scheme and elements gives the piece an intriguing narrative quality. painted realism, photorealistic, 8k, fantasy digital art, HDR, UHD."
]
},
{
"id": 4,
"type": "CheckpointLoaderSimple",
"pos": [
-252,
425
],
"size": {
"0": 315,
"1": 98
},
"flags": {},
"order": 1,
"mode": 0,
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
1
],
"slot_index": 0
},
{
"name": "CLIP",
"type": "CLIP",
"links": [],
"slot_index": 1
},
{
"name": "VAE",
"type": "VAE",
"links": [
8
],
"slot_index": 2
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple"
},
"widgets_values": [
"Dark_Sushi_Mix.safetensors"
]
},
{
"id": 10,
"type": "SeaArtLongClip",
"pos": [
-248,
204
],
"size": {
"0": 315,
"1": 58
},
"flags": {},
"order": 2,
"mode": 0,
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
18,
19
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SeaArtLongClip"
},
"widgets_values": [
"longclip-L.pt"
]
}
],
"links": [
[
1,
4,
0,
3,
0,
"MODEL"
],
[
2,
5,
0,
3,
3,
"LATENT"
],
[
4,
6,
0,
3,
1,
"CONDITIONING"
],
[
6,
7,
0,
3,
2,
"CONDITIONING"
],
[
7,
3,
0,
8,
0,
"LATENT"
],
[
8,
4,
2,
8,
1,
"VAE"
],
[
9,
8,
0,
9,
0,
"IMAGE"
],
[
18,
10,
0,
6,
0,
"CLIP"
],
[
19,
10,
0,
7,
0,
"CLIP"
]
],
"groups": [],
"config": {},
"extra": {},
"version": 0.4
}
================================================
FILE: workflow/sdxl-long.json
================================================
{
"last_node_id": 13,
"last_link_id": 29,
"nodes": [
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
413,
389
],
"size": {
"0": 425.27801513671875,
"1": 180.6060791015625
},
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 29
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
6
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"text, watermark"
]
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1209,
188
],
"size": {
"0": 210,
"1": 46
},
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7
},
{
"name": "vae",
"type": "VAE",
"link": 8
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
9
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAEDecode"
}
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1451,
189
],
"size": {
"0": 210,
"1": 270
},
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 9
}
],
"properties": {},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 5,
"type": "EmptyLatentImage",
"pos": [
473,
609
],
"size": {
"0": 315,
"1": 106
},
"flags": {},
"order": 0,
"mode": 0,
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
2
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
1024,
1024,
1
]
},
{
"id": 3,
"type": "KSampler",
"pos": [
863,
186
],
"size": {
"0": 315,
"1": 262
},
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 1
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 4
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 6
},
{
"name": "latent_image",
"type": "LATENT",
"link": 2
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
7
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
425533035839016,
"fixed",
20,
7,
"euler_ancestral",
"normal",
1
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
415,
186
],
"size": {
"0": 422.84503173828125,
"1": 164.31304931640625
},
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 28
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
4
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"Studio Ghibli ~ Hayao Miyazaki ~ Spirited Away ~ beautiful ultra detailed illustration of a cute witch of the sitting on a tree stump reading a book, her ornate robe reminiscent of the stars in the night sky. This contrast between the fantastical character and the more bold color scheme and elements gives the piece an intriguing narrative quality. painted realism, photorealistic, 8k, fantasy digital art, HDR, UHD."
]
},
{
"id": 4,
"type": "CheckpointLoaderSimple",
"pos": [
-387,
375
],
"size": {
"0": 315,
"1": 98
},
"flags": {},
"order": 1,
"mode": 0,
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
1
],
"slot_index": 0
},
{
"name": "CLIP",
"type": "CLIP",
"links": [
27
],
"slot_index": 1
},
{
"name": "VAE",
"type": "VAE",
"links": [
8
],
"slot_index": 2
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple"
},
"widgets_values": [
"sdxl.safetensors"
]
},
{
"id": 13,
"type": "SeaArtLongXLClipMerge",
"pos": [
-4,
217
],
"size": {
"0": 315,
"1": 58
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 27
}
],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
28,
29
],
"shape": 3,
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SeaArtLongXLClipMerge"
},
"widgets_values": [
"longclip-L.pt"
]
}
],
"links": [
[
1,
4,
0,
3,
0,
"MODEL"
],
[
2,
5,
0,
3,
3,
"LATENT"
],
[
4,
6,
0,
3,
1,
"CONDITIONING"
],
[
6,
7,
0,
3,
2,
"CONDITIONING"
],
[
7,
3,
0,
8,
0,
"LATENT"
],
[
8,
4,
2,
8,
1,
"VAE"
],
[
9,
8,
0,
9,
0,
"IMAGE"
],
[
27,
4,
1,
13,
0,
"CLIP"
],
[
28,
13,
0,
6,
0,
"CLIP"
],
[
29,
13,
0,
7,
0,
"CLIP"
]
],
"groups": [],
"config": {},
"extra": {},
"version": 0.4
}
gitextract_dj6dn1g3/
├── README-CN.md
├── README.md
├── __init__.py
├── long_clip.py
├── long_clip_model/
│ ├── longclip.py
│ ├── model_longclip.py
│ └── simple_tokenizer.py
└── workflow/
├── flux-long.json
├── sd1-5-long.json
└── sdxl-long.json
SYMBOL INDEX (101 symbols across 4 files)
FILE: long_clip.py
class SDLongClipModel (line 12) | class SDLongClipModel(torch.nn.Module, ClipTokenWeightEncoder):
method __init__ (line 19) | def __init__(self, version="openai/clip-vit-large-patch14", device="cp...
method freeze (line 51) | def freeze(self):
method clip_layer (line 57) | def clip_layer(self, layer_idx):
method reset_clip_layer (line 64) | def reset_clip_layer(self):
method set_clip_options (line 68) | def set_clip_options(self, options):
method reset_clip_options (line 77) | def reset_clip_options(self):
method set_up_textual_embeddings (line 82) | def set_up_textual_embeddings(self, tokens, current_embeds):
method forward (line 121) | def forward(self, tokens):
method encode (line 154) | def encode(self, tokens):
method load_sd (line 157) | def load_sd(self, sd):
class SDLongTokenizer (line 164) | class SDLongTokenizer:
method __init__ (line 165) | def __init__(self, max_length=248, pad_with_end=True, embedding_direct...
method _try_get_embedding (line 189) | def _try_get_embedding(self, embedding_name:str):
method tokenize_with_weights (line 203) | def tokenize_with_weights(self, text:str, return_word_ids=False):
method untokenize (line 286) | def untokenize(self, token_weight_pair):
function pad_tokens (line 289) | def pad_tokens(tokens,clip,add_token_num):
function token_num (line 303) | def token_num(tokens):
class SDXLLongClipModel (line 309) | class SDXLLongClipModel(torch.nn.Module):
method __init__ (line 310) | def __init__(self):
method set_clip_options (line 315) | def set_clip_options(self, options):
method reset_clip_options (line 319) | def reset_clip_options(self):
method encode_token_weights (line 323) | def encode_token_weights(self, token_weight_pairs):
method load_sd (line 335) | def load_sd(self, sd):
class SDXLLongTokenizer (line 341) | class SDXLLongTokenizer:
method __init__ (line 342) | def __init__(self):
method tokenize_with_weights (line 346) | def tokenize_with_weights(self, text:str, return_word_ids=False):
method untokenize (line 358) | def untokenize(self, token_weight_pair):
class LongCLIPFluxModel (line 361) | class LongCLIPFluxModel(torch.nn.Module):
method __init__ (line 362) | def __init__(self):
method set_clip_options (line 367) | def set_clip_options(self, options):
method reset_clip_options (line 371) | def reset_clip_options(self):
method encode_token_weights (line 375) | def encode_token_weights(self, token_weight_pairs):
method load_sd (line 386) | def load_sd(self, sd):
class LongCLIPFluxTokenizer (line 392) | class LongCLIPFluxTokenizer:
method __init__ (line 393) | def __init__(self):
method tokenize_with_weights (line 397) | def tokenize_with_weights(self, text: str, return_word_ids=False):
method untokenize (line 413) | def untokenize(self, token_weight_pair):
method state_dict (line 417) | def state_dict(self):
class SeaArtLongXLClipMerge (line 420) | class SeaArtLongXLClipMerge:
method INPUT_TYPES (line 422) | def INPUT_TYPES(cls):
method do (line 431) | def do(self, clip_name, clip):
class SeaArtLongClip (line 450) | class SeaArtLongClip:
method INPUT_TYPES (line 452) | def INPUT_TYPES(cls):
method do (line 460) | def do(self, clip_name):
class LongCLIPTextEncodeFlux (line 472) | class LongCLIPTextEncodeFlux:
method INPUT_TYPES (line 474) | def INPUT_TYPES(cls):
method do (line 484) | def do(self, clip_name, clip):
FILE: long_clip_model/longclip.py
function _convert_image_to_rgb (line 36) | def _convert_image_to_rgb(image):
function _transform (line 40) | def _transform(n_px):
function load (line 51) | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.c...
function load_from_clip (line 144) | def load_from_clip(name: str, device: Union[str, torch.device] = "cuda" ...
function only_tokenize (line 321) | def only_tokenize(texts: Union[str, List[str]]) -> Union[torch.IntTensor...
function tokenize (line 330) | def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60...
FILE: long_clip_model/model_longclip.py
class Bottleneck (line 10) | class Bottleneck(nn.Module):
method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
method forward (line 42) | def forward(self, x: torch.Tensor):
class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module):
method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
method forward (line 68) | def forward(self, x):
class ModifiedResNet (line 94) | class ModifiedResNet(nn.Module):
method __init__ (line 102) | def __init__(self, layers, output_dim, heads, input_resolution=224, wi...
method _make_layer (line 129) | def _make_layer(self, planes, blocks, stride=1):
method forward (line 138) | def forward(self, x):
class LayerNorm (line 157) | class LayerNorm(nn.LayerNorm):
method forward (line 160) | def forward(self, x: torch.Tensor):
class QuickGELU (line 166) | class QuickGELU(nn.Module):
method forward (line 167) | def forward(self, x: torch.Tensor):
class ResidualAttentionBlock (line 171) | class ResidualAttentionBlock(nn.Module):
method __init__ (line 172) | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor ...
method attention (line 185) | def attention(self, x: torch.Tensor):
method forward (line 189) | def forward(self, x: torch.Tensor):
class Transformer (line 195) | class Transformer(nn.Module):
method __init__ (line 196) | def __init__(self, width: int, layers: int, heads: int, attn_mask: tor...
method forward (line 202) | def forward(self, x: torch.Tensor, intermediate_output=None, attn_mask...
class VisionTransformer (line 215) | class VisionTransformer(nn.Module):
method __init__ (line 216) | def __init__(self, input_resolution: int, patch_size: int, width: int,...
method forward (line 232) | def forward(self, x: torch.Tensor):
class CLIP (line 252) | class CLIP(nn.Module):
method __init__ (line 253) | def __init__(self,
method initialize_parameters (line 321) | def initialize_parameters(self):
method build_attention_mask (line 350) | def build_attention_mask(self):
method dtype (line 359) | def dtype(self):
method get_input_embeddings (line 362) | def get_input_embeddings(self):
method set_input_embeddings (line 365) | def set_input_embeddings(self, embeddings):
method encode_image (line 368) | def encode_image(self, image):
method encode_text (line 371) | def encode_text(self, text):
method encode_text_full (line 387) | def encode_text_full(self, text):
method forward (line 403) | def forward(self, input_tokens, attention_mask=None, intermediate_outp...
function convert_weights (line 432) | def convert_weights(model: nn.Module):
function build_model (line 456) | def build_model(state_dict: dict, load_from_clip: bool):
FILE: long_clip_model/simple_tokenizer.py
function default_bpe (line 11) | def default_bpe():
function bytes_to_unicode (line 16) | def bytes_to_unicode():
function get_pairs (line 38) | def get_pairs(word):
function basic_clean (line 50) | def basic_clean(text):
function whitespace_clean (line 56) | def whitespace_clean(text):
class SimpleTokenizer (line 62) | class SimpleTokenizer(object):
method __init__ (line 63) | def __init__(self, bpe_path: str = default_bpe()):
method bpe (line 80) | def bpe(self, token):
method encode (line 121) | def encode(self, text):
method decode (line 129) | def decode(self, tokens):
Condensed preview — 10 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (100K chars).
[
{
"path": "README-CN.md",
"chars": 764,
"preview": "# ComfyUI-Long-CLIP\n本项目是long-clip的comfyui实现,目前支持clip-l的替换,对于SD1.5可以使用SeaArtLongClip模块加载后替换模型中原本的clip,token的长度由77扩大至248,经"
},
{
"path": "README.md",
"chars": 1678,
"preview": "# ComfyUI-Long-CLIP (Flux Suport Now)\nThis project implements the comfyui for long-clip, currently supporting the replac"
},
{
"path": "__init__.py",
"chars": 422,
"preview": "from . import long_clip as long_clip\n\nNODE_CLASS_MAPPINGS = {\n \"SeaArtLongClip\": long_clip.SeaArtLongClip,\n \"SeaAr"
},
{
"path": "long_clip.py",
"chars": 20318,
"preview": "from .long_clip_model import longclip\nimport os\nimport torch\nfrom comfy.sd import CLIP\nimport folder_paths\nfrom comfy.sd"
},
{
"path": "long_clip_model/longclip.py",
"chars": 15068,
"preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Any, Union, List\nfrom torch import nn\nimport t"
},
{
"path": "long_clip_model/model_longclip.py",
"chars": 20379,
"preview": "from collections import OrderedDict\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.fun"
},
{
"path": "long_clip_model/simple_tokenizer.py",
"chars": 4628,
"preview": "import gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef def"
},
{
"path": "workflow/flux-long.json",
"chars": 16075,
"preview": "{\n \"last_node_id\": 38,\n \"last_link_id\": 118,\n \"nodes\": [\n {\n \"id\": 17,\n \"type\": \"BasicScheduler\",\n "
},
{
"path": "workflow/sd1-5-long.json",
"chars": 6707,
"preview": "{\n \"last_node_id\": 10,\n \"last_link_id\": 19,\n \"nodes\": [\n {\n \"id\": 7,\n \"type\": \"CLIPTextEncode\",\n \"p"
},
{
"path": "workflow/sdxl-long.json",
"chars": 6938,
"preview": "{\n \"last_node_id\": 13,\n \"last_link_id\": 29,\n \"nodes\": [\n {\n \"id\": 7,\n \"type\": \"CLIPTextEncode\",\n \"p"
}
]
About this extraction
This page contains the full source code of the SeaArtLab/ComfyUI-Long-CLIP GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 10 files (90.8 KB), approximately 24.5k tokens, and a symbol index with 101 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.