Add DualClipLoader to load clip models for SDXL.

Update LoadClip to load clip models for SDXL refiner.
This commit is contained in:
comfyanonymous 2023-06-25 01:40:38 -04:00
parent b7933960bb
commit 20f579d91d
4 changed files with 67 additions and 11 deletions

View File

@ -19,6 +19,7 @@ from . import model_detection
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -524,7 +525,7 @@ class CLIP:
return n
def load_from_state_dict(self, sd):
self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
self.cond_stage_model.load_sd(sd)
def add_patches(self, patches, strength=1.0):
return self.patcher.add_patches(patches, strength)
@ -555,6 +556,8 @@ class CLIP:
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)
class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
return StyleModel(model)
def load_clip(ckpt_path, embedding_directory=None):
clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
def load_clip(ckpt_paths, embedding_directory=None):
clip_data = []
for p in ckpt_paths:
clip_data.append(utils.load_torch_file(p, safe_load=True))
class EmptyClass:
pass
for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
clip_target = EmptyClass()
clip_target.params = {}
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
clip.load_from_state_dict(clip_data)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
print("clip missing:", m)
if len(u) > 0:
print("clip unexpected:", u)
return clip
def load_gligen(ckpt_path):

View File

@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def encode(self, tokens):
return self(tokens)
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
result = []
current_item = ""

View File

@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
self.layer = "hidden"
self.layer_idx = layer_idx
def load_sd(self, sd):
if "text_projection" in sd:
self.text_projection[:] = sd.pop("text_projection")
return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
else:
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
super().__init__()
@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled
def load_sd(self, sd):
return self.clip_g.load_sd(sd)

View File

@ -520,11 +520,27 @@ class CLIPLoader:
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "loaders"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class DualCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class CLIPVisionLoader:
@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
"LatentCrop": LatentCrop,
"LoraLoader": LoraLoader,
"CLIPLoader": CLIPLoader,
"DualCLIPLoader": DualCLIPLoader,
"CLIPVisionEncode": CLIPVisionEncode,
"StyleModelApply": StyleModelApply,
"unCLIPConditioning": unCLIPConditioning,