Cleaner CLIP text encoder implementation.
Use a simple CLIP model implementation instead of the one from transformers. This will allow some interesting things that would be too hackish to implement using the transformers implementation.
parent 2db86b4676
commit fbdb14d4c4
@@ -0,0 +1,126 @@
+import torch
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+class CLIPAttention(torch.nn.Module):
+    def __init__(self, embed_dim, heads, dtype, device, operations):
+        super().__init__()
+
+        self.heads = heads
+        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        out = optimized_attention(q, k, v, self.heads, mask)
+        return self.out_proj(out)
+
+ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+               "gelu": torch.nn.functional.gelu,
+}
+
+class CLIPMLP(torch.nn.Module):
+    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+        super().__init__()
+        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+        self.activation = ACTIVATIONS[activation]
+        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+
+class CLIPLayer(torch.nn.Module):
+    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+        x += self.mlp(self.layer_norm2(x))
+        return x
+
+
+class CLIPEncoder(torch.nn.Module):
+    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+    def forward(self, x, mask=None, intermediate_output=None):
+        optimized_attention = optimized_attention_for_device(x.device, mask=True)
+        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        if mask is not None:
+            mask += causal_mask
+        else:
+            mask = causal_mask
+
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = len(self.layers) + intermediate_output
+
+        intermediate = None
+        for i, l in enumerate(self.layers):
+            x = l(x, mask, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        return x, intermediate
+
+class CLIPEmbeddings(torch.nn.Module):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+        super().__init__()
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens):
+        return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+class CLIPTextModel_(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        num_layers = config_dict["num_hidden_layers"]
+        embed_dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        intermediate_size = config_dict["intermediate_size"]
+        intermediate_activation = config_dict["hidden_act"]
+
+        super().__init__()
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+        x = self.embeddings(input_tokens)
+        #TODO: attention_mask
+        x, i = self.encoder(x, intermediate_output=intermediate_output)
+        x = self.final_layer_norm(x)
+        if i is not None and final_layer_norm_intermediate:
+            i = self.final_layer_norm(i)
+
+        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+        return x, i, pooled_output
+
+class CLIPTextModel(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.num_layers = config_dict["num_hidden_layers"]
+        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+        self.dtype = dtype
+
+    def get_input_embeddings(self):
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, embeddings):
+        self.text_model.embeddings.token_embedding = embeddings
+
+    def forward(self, *args, **kwargs):
+        return self.text_model(*args, **kwargs)
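As a rough usage sketch (not part of the commit): the new model only needs a plain config dict and an `operations` namespace that exposes `Linear` and `LayerNorm`, so a stand-in built from stock torch layers is enough to see the three outputs. The config values below are the usual SD1/ViT-L text-encoder sizes and are assumptions for illustration, not copied from the repo's JSON.

import types
import torch
from comfy.clip_model import CLIPTextModel

# Stand-in for comfy.ops: anything exposing Linear and LayerNorm works here.
ops = types.SimpleNamespace(Linear=torch.nn.Linear, LayerNorm=torch.nn.LayerNorm)

config = {"num_hidden_layers": 12, "hidden_size": 768, "num_attention_heads": 12,
          "intermediate_size": 3072, "hidden_act": "quick_gelu"}  # assumed SD1-style sizes

model = CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=ops)

tokens = torch.randint(0, 49408, (1, 77))  # one 77-token prompt of random ids
last_hidden, intermediate, pooled = model(tokens, intermediate_output=-2)
print(last_hidden.shape, intermediate.shape, pooled.shape)
# torch.Size([1, 77, 768]) torch.Size([1, 77, 768]) torch.Size([1, 768])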
@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
     del q, k
 
     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask
 
     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
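The rewritten branch above lets attention_basic take either a boolean padding mask (the old behaviour) or an additive float bias, which is what the new CLIP encoder passes for its causal mask. A small self-contained sanity sketch of why the two forms end up equivalent after softmax (nothing here is comfy-specific):

import torch

scores = torch.randn(1, 4, 4)                   # raw attention logits for one head
keep = torch.tensor([True, True, True, False])  # boolean form: True = attend

# Boolean route: fill masked positions with a very negative value.
bool_masked = scores.masked_fill(~keep, -torch.finfo(scores.dtype).max)

# Additive route: build a float bias of 0 / -inf and just add it.
bias = torch.zeros(4)
bias[~keep] = float("-inf")
add_masked = scores + bias

# After softmax both give (near) zero weight to the masked column.
print(torch.allclose(bool_masked.softmax(-1), add_masked.softmax(-1), atol=1e-6))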
@@ -340,6 +343,18 @@ else:
     if model_management.pytorch_attention_enabled():
         optimized_attention_masked = attention_pytorch
 
+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
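optimized_attention_for_device is what clip_model.py calls to pick an attention kernel; whatever it returns is used with the calling convention seen in CLIPAttention.forward, i.e. fn(q, k, v, heads, mask) with q, k, v shaped (batch, seq, heads * dim_head) and mask either None or an additive bias. A hedged reference kernel with that convention, for reading purposes only (the real kernels in attention.py are the optimized ones):

import torch

def reference_attention(q, k, v, heads, mask=None):
    # q, k, v: (batch, seq, heads * dim_head); mask: additive bias or None.
    b, seq, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, seq, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / dim_head ** 0.5
    if mask is not None:
        scores = scores + mask
    out = scores.softmax(dim=-1) @ v
    return out.transpose(1, 2).reshape(b, seq, inner)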
@@ -1,12 +1,14 @@
 import os
 
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+from transformers import CLIPTokenizer
 import comfy.ops
 import torch
 import traceback
 import zipfile
 from . import model_management
 import contextlib
+import comfy.clip_model
+import json
 
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -65,35 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
-                 model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
-        self.num_layers = 12
-        if textmodel_path is not None:
-            self.transformer = model_class.from_pretrained(textmodel_path)
-        else:
-            if textmodel_json_config is None:
-                textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = config_class.from_json_file(textmodel_json_config)
-            self.num_layers = config.num_hidden_layers
-            with comfy.ops.use_comfy_ops(device, dtype):
-                with modeling_utils.no_init_weights():
-                    self.transformer = model_class(config)
 
-        self.inner_name = inner_name
-        if dtype is not None:
-            inner_model = getattr(self.transformer, self.inner_name)
-            if hasattr(inner_model, "embeddings"):
-                embeddings_bak = inner_model.embeddings.to(torch.float32)
-                inner_model.embeddings = None
-                self.transformer.to(dtype)
-                inner_model.embeddings = embeddings_bak
-            else:
-                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
-                self.transformer.to(dtype)
-                self.transformer.set_input_embeddings(previous_inputs)
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+
+        with open(textmodel_json_config) as f:
+            config = json.load(f)
+
+        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
         if freeze:
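Because the config now arrives as a plain dict from json.load rather than a transformers CLIPTextConfig, only a handful of keys actually matter: the five read in CLIPTextModel_.__init__ (num_hidden_layers is also reused by the outer CLIPTextModel). A small hedged check, with the path written out repo-relative rather than via __file__ as the real code does:

import json
import os

# Keys consumed by comfy.clip_model.CLIPTextModel_; anything else in the JSON is ignored
# by the new implementation.
REQUIRED_KEYS = ("num_hidden_layers", "hidden_size", "num_attention_heads",
                 "intermediate_size", "hidden_act")

config_path = os.path.join("comfy", "sd1_clip_config.json")  # assumed repo-relative location
with open(config_path) as f:
    config = json.load(f)

missing = [k for k in REQUIRED_KEYS if k not in config]
assert not missing, f"config is missing {missing}"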
@@ -108,7 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= self.num_layers
+            assert abs(layer_idx) < self.num_layers
             self.clip_layer(layer_idx)
         self.layer_default = (self.layer, self.layer_idx)
 
@@ -119,7 +105,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
 
     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= self.num_layers:
+        if abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
@@ -174,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
+        if self.transformer.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, dtype: contextlib.nullcontext(a)
@@ -190,20 +176,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     if tokens[x, y] == max_token:
                         break
 
-            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
+            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
             self.transformer.set_input_embeddings(backup_embeds)
 
             if self.layer == "last":
-                z = outputs.last_hidden_state
-            elif self.layer == "pooled":
-                z = outputs.pooler_output[:, None, :]
+                z = outputs[0]
             else:
-                z = outputs.hidden_states[self.layer_idx]
-                if self.layer_norm_hidden_state:
-                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
+                z = outputs[1]
 
-            if hasattr(outputs, "pooler_output"):
-                pooled_output = outputs.pooler_output.float()
+            if outputs[2] is not None:
+                pooled_output = outputs[2].float()
             else:
                 pooled_output = None
 
@@ -3,13 +3,13 @@ import torch
 import os
 
 class SD2ClipHModel(sd1_clip.SDClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
-            layer_idx=23
+            layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
@@ -3,13 +3,13 @@ import torch
 import os
 
 class SDXLClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
 
     def load_sd(self, sd):
@@ -37,7 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
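All the call sites that used to pass absolute hidden-state indices (layer_idx=23 for SD2, layer_idx=11 for SDXL's clip_l) now pass -2; the value flows into the encoder's intermediate_output, where negatives are resolved against the layer count, so the same "penultimate" request works for any encoder depth. A quick check of that resolution rule (layer counts here are the usual ones for these encoders and are stated as assumptions):

def resolve_intermediate(layer_idx, num_layers):
    # Mirrors CLIPEncoder.forward: negative indices count back from the last layer.
    if layer_idx is not None and layer_idx < 0:
        return num_layers + layer_idx
    return layer_idx

print(resolve_intermediate(-2, 24))  # SD2.x ViT-H text encoder (24 layers) -> 22
print(resolve_intermediate(-2, 32))  # SDXL clip_g / bigG (32 layers)       -> 30
print(resolve_intermediate(-2, 12))  # SD1 / SDXL clip_l (12 layers)        -> 10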