Use transformers CLIP instead of open_clip for SD2.x
This should make things a bit cleaner.
parent bf9ccffb17
commit 56d802e1f3

comfy/sd.py (73 changed lines)
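In short: instead of building an open_clip ViT-H-14 model and deleting its visual tower, the SD2.x text encoder becomes a plain transformers CLIPTextModel constructed from a local JSON config, and the checkpoint's open_clip keys are renamed to match. A minimal sketch of that new path, illustrative only and not code from this commit (the config values mirror the sd2_clip_config.json added below):

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Same shape as the SD2.x / open_clip ViT-H-14 text tower.
config = CLIPTextConfig(hidden_size=1024, intermediate_size=4096,
                        num_hidden_layers=24, num_attention_heads=16,
                        max_position_embeddings=77, hidden_act="gelu")
text_model = CLIPTextModel(config)  # weights are filled in later from the converted checkpoint

tokens = torch.tensor([[49406, 49407] + [0] * 75])   # BOS, EOS, padding (the "empty_tokens" pattern)
out = text_model(input_ids=tokens, output_hidden_states=True)
penultimate = out.hidden_states[-2]                   # SD2.x conditions on the penultimate layer
print(penultimate.shape)                              # torch.Size([1, 77, 1024])
```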
comfy/sd.py

@@ -40,6 +40,42 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
         if ids.dtype == torch.float32:
             sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
+    keys_to_replace = {
+        "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+        "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+        "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
+        "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
+    }
+
+    for x in keys_to_replace:
+        if x in sd:
+            sd[keys_to_replace[x]] = sd.pop(x)
+
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(24):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
+                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
+                    sd[k_to] = weights[1024*x:1024*(x + 1)]
+
     for x in load_state_dict_to:
         x.load_state_dict(sd, strict=False)
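For reference, a small standalone check of the in_proj split done above; this is not part of the diff. open_clip stores attention q/k/v as a single fused (3*1024, 1024) in_proj tensor, while the transformers text model expects separate q_proj/k_proj/v_proj weights, so the loop slices the fused tensor into three 1024-row chunks.

```python
import torch

hidden = 1024
in_proj_weight = torch.randn(3 * hidden, hidden)    # fused q/k/v layout used by open_clip
q, k, v = (in_proj_weight[i * hidden:(i + 1) * hidden] for i in range(3))
assert q.shape == k.shape == v.shape == (hidden, hidden)
# The diff performs the same slicing per resblock and stores the chunks under
# ...self_attn.q_proj / k_proj / v_proj in the converted state dict.
```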
@@ -62,12 +98,6 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }
 
-LORA_CLIP2_MAP = {
-    "mlp.c_fc": "mlp_fc1",
-    "mlp.c_proj": "mlp_fc2",
-    "attn.out_proj": "self_attn_out_proj",
-}
-
 LORA_UNET_MAP = {
     "proj_in": "proj_in",
     "proj_out": "proj_out",
@@ -116,7 +146,7 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
@@ -124,7 +154,7 @@ def model_lora_keys(model, key_map={}):
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
-            key_map[lora_key] = (k, 0)
+            key_map[lora_key] = k
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
@@ -133,29 +163,18 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
     counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
-    for b in range(12):
+    for b in range(24):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
-    for b in range(24):
-        for c in LORA_CLIP2_MAP:
-            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
-            if k in sdk:
-                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
-                key_map[lora_key] = (k, 0)
-        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
-        if k in sdk:
-            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0)
-            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1)
-            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)
 
     return key_map
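What dropping the (key, index) tuples means for the LoRA key map, roughly (key names here are illustrative, not taken from a real LoRA file): with open_clip a single model key held the fused q/k/v weight, so each text-encoder LoRA target needed a slice index, while with the transformers text model every target has its own key and a plain string suffices.

```python
# Old style: value is (model_key, slice_index) into a fused in_proj weight.
old_style = {
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj":
        ("model.transformer.resblocks.0.attn.in_proj_weight", 0),
}

# New style: value is simply the model key of the matching q_proj weight.
new_style = {
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj":
        "transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
}
```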
@@ -174,7 +193,7 @@ class ModelPatcher:
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
-            if k[0] in model_sd:
+            if k in model_sd:
                 p[k] = patches[k]
         self.patches += [(strength, p)]
         return p.keys()
@@ -184,8 +203,7 @@ class ModelPatcher:
         for p in self.patches:
             for k in p[1]:
                 v = p[1][k]
-                key = k[0]
-                index = k[1]
+                key = k
                 if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
@@ -199,10 +217,7 @@ class ModelPatcher:
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
-                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
-                if len(weight.shape) > 2:
-                    calc = calc.reshape(weight.shape)
-                weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
         return self.model
     def unpatch_model(self):
         model_sd = self.model.state_dict()
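A standalone sketch of the simplified patch math above, with illustrative shapes and values: because a patch key now addresses a whole weight tensor, the LoRA delta can be added in one step instead of into an indexed row slice.

```python
import torch

weight = torch.zeros(1024, 1024)   # target weight from the model's state dict
mat1 = torch.randn(1024, 8)        # LoRA "up" matrix
mat2 = torch.randn(8, 1024)        # LoRA "down" matrix
strength, lora_alpha = 1.0, 8.0
alpha = strength * lora_alpha / mat2.shape[0]   # mirrors alpha *= v[2] / mat2.shape[0]

delta = alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())
weight += delta.reshape(weight.shape).type(weight.dtype).to(weight.device)
print(weight.abs().mean())         # non-zero: the delta was applied to the full tensor
```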
comfy/sd2_clip.py

@@ -1,86 +1,22 @@
 import sd1_clip
-import open_clip
 import torch
+import os
 
-class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder):
-    """
-    Uses the OpenCLIP transformer encoder for text
-    """
-    LAYERS = [
-        #"pooled",
-        "last",
-        "penultimate",
-        "hidden"
-    ]
-    #version="laion2b_s32b_b79k"
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77,
-                 freeze=True, layer="penultimate", layer_idx=None):
-        super().__init__()
-        assert layer in self.LAYERS
-        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
-        del model.visual
-        self.model = model
-
-        self.device = device
-        self.max_length = max_length
+class SD2ClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        if freeze:
-            self.freeze()
-        self.layer = layer
-        if self.layer == "last":
-            self.layer_idx = 0
-        elif self.layer == "penultimate":
-            self.layer_idx = 1
+        if layer == "last":
+            layer_idx = -1
+        elif layer == "penultimate":
+            layer_idx = -2
         elif self.layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) < 24
-            self.clip_layer(layer_idx)
         else:
             raise NotImplementedError()
+        self.clip_layer(layer_idx)
-
-    def freeze(self):
-        self.model = self.model.eval()
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def clip_layer(self, layer_idx):
-        #layer_idx should have the same logic as the one for SD1
-        if abs(layer_idx) >= 24:
-            self.layer_idx = 0
-        else:
-            if layer_idx < 0:
-                self.layer_idx = -(layer_idx + 1)
-            else:
-                self.layer_idx = 24 - (layer_idx + 1)
-
-    def forward(self, tokens):
-        tokens = torch.LongTensor(tokens).to(self.device)
-        z = self.encode_with_transformer(tokens)
-        return z
-
-    def encode_with_transformer(self, tokens):
-        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
-        x = x + self.model.positional_embedding
-        x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
-        x = x.permute(1, 0, 2)  # LND -> NLD
-        x = self.model.ln_final(x)
-        return x
-
-    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
-        for i, r in enumerate(self.model.transformer.resblocks):
-            if i == len(self.model.transformer.resblocks) - self.layer_idx:
-                break
-            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint(r, x, attn_mask)
-            else:
-                x = r(x, attn_mask=attn_mask)
-        return x
-
-    def encode(self, tokens):
-        return self(tokens)
-
 
 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None):
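The rewritten SD2ClipModel leans on the layer handling inherited from sd1_clip.SD1ClipModel; the convention it sets up is roughly the following (hypothetical helper, not from the diff): "last" maps to hidden state -1, "penultimate" to -2, and "hidden" takes an explicit index that must stay within the 24 encoder layers.

```python
def pick_layer_idx(layer, layer_idx=None, num_layers=24):
    """Hypothetical helper mirroring the layer -> layer_idx mapping in SD2ClipModel.__init__."""
    if layer == "last":
        return -1
    if layer == "penultimate":
        return -2
    if layer == "hidden":
        assert layer_idx is not None and abs(layer_idx) < num_layers
        return layer_idx
    raise NotImplementedError(layer)

print(pick_layer_idx("penultimate"))   # -2
print(pick_layer_idx("hidden", -12))   # -12
```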
comfy/sd2_clip_config.json (new file)

@@ -0,0 +1,23 @@
+{
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_size": 1024,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 1,
+  "projection_dim": 512,
+  "torch_dtype": "float32",
+  "vocab_size": 49408
+}
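A hedged sanity check of the new config (the path is assumed to be resolved from the repository root): it describes the text tower of open_clip's ViT-H-14 as a 24-layer, 1024-wide, 16-head CLIPTextModel with a 77-token context, which is what SD2.x checkpoints ship.

```python
from transformers import CLIPTextConfig, CLIPTextModel

cfg = CLIPTextConfig.from_json_file("comfy/sd2_clip_config.json")
assert (cfg.num_hidden_layers, cfg.hidden_size, cfg.num_attention_heads) == (24, 1024, 16)
assert cfg.max_position_embeddings == 77 and cfg.vocab_size == 49408

model = CLIPTextModel(cfg)
n_params = sum(p.numel() for p in model.parameters())
print(f"text encoder parameters: {n_params / 1e6:.0f}M")   # roughly 350M
```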