SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.
This commit is contained in:
parent 6ec3f12c6e
commit e60ca6929a
comfy/lora.py

@@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
-    for b in range(32):
+    for b in range(32): #TODO: clean up
         for c in LORA_CLIP_MAP:
-            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k
@@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):
 
             k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
+                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = k
                 lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                 key_map[lora_key] = k
                 clip_l_present = True
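For context: `model_lora_keys_clip` builds a dict that maps LoRA-file weight names onto the model's state-dict keys, and the two hunks above namespace the lookups under the new `clip_h.` / `clip_l.` attributes. A minimal runnable sketch of that logic, assuming a trimmed-down `LORA_CLIP_MAP` (the real table in the source maps more submodule names):

```python
# Minimal sketch of the key-mapping logic above. This LORA_CLIP_MAP is a
# trimmed assumption; the real table has many more entries.
LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "self_attn.q_proj": "self_attn_q_proj",
}

def build_clip_lora_key_map(sdk, key_map=None):
    key_map = {} if key_map is None else key_map
    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    for b in range(32):
        for c in LORA_CLIP_MAP:
            # SD2-style encoder weights now live under the "clip_h." prefix.
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, LORA_CLIP_MAP[c])] = k
            # "clip_l." weights answer to both the legacy "lora_te_" name
            # and the SDXL-base "lora_te1_" name.
            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, LORA_CLIP_MAP[c])] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])] = k
    return key_map

sdk = {"clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight"}
for lora_key, model_key in build_clip_lora_key_map(sdk).items():
    print(lora_key, "->", model_key)
# lora_te_text_model_encoder_layers_0_self_attn_q_proj  -> clip_l.transformer...
# lora_te1_text_model_encoder_layers_0_self_attn_q_proj -> clip_l.transformer...
```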
comfy/sd1_clip.py

@@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
             return z_empty.cpu(), first_pooled.cpu()
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
-class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
@@ -342,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
         embed_out = next(iter(values))
     return embed_out
 
-class SD1Tokenizer:
+class SDTokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
@@ -454,3 +454,40 @@ class SD1Tokenizer:
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
 
+
+class SD1Tokenizer:
+    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
+        self.clip_name = clip_name
+        self.clip = "clip_{}".format(self.clip_name)
+        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
+        out = {}
+        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+        return out
+
+    def untokenize(self, token_weight_pair):
+        return getattr(self, self.clip).untokenize(token_weight_pair)
+
+
+class SD1ClipModel(torch.nn.Module):
+    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
+        super().__init__()
+        self.clip_name = clip_name
+        self.clip = "clip_{}".format(self.clip_name)
+        setattr(self, self.clip, clip_model(device=device, dtype=dtype))
+
+    def clip_layer(self, layer_idx):
+        getattr(self, self.clip).clip_layer(layer_idx)
+
+    def reset_clip_layer(self):
+        getattr(self, self.clip).reset_clip_layer()
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs = token_weight_pairs[self.clip_name]
+        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out, pooled
+
+    def load_sd(self, sd):
+        return getattr(self, self.clip).load_sd(sd)
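The two classes added above are the heart of the refactor: each stores its single sub-tokenizer or sub-encoder under a dynamic attribute name (`clip_l` by default) and keys its output by `clip_name`, so SD1 now exposes the same dict-shaped interface as SDXL's dual-encoder classes. A self-contained sketch of the delegation pattern, with a stub standing in for `SDTokenizer` (the real one needs tokenizer files on disk):

```python
# Delegation-pattern sketch. StubTokenizer is a hypothetical stand-in for
# SDTokenizer, which requires real tokenizer files to construct.
class StubTokenizer:
    def __init__(self, embedding_directory=None):
        pass

    def tokenize_with_weights(self, text, return_word_ids=False):
        # Fake (token_id, weight) pairs; the real tokenizer does CLIP BPE.
        return [(hash(w) % 49408, 1.0) for w in text.split()]

class WrapperTokenizer:
    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=StubTokenizer):
        self.clip_name = clip_name
        self.clip = "clip_{}".format(clip_name)
        # Stored under "clip_l", "clip_h", ... to match state-dict prefixes.
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text, return_word_ids=False):
        # Dict keyed by clip_name, the same shape SDXLTokenizer returns
        # for its "l" and "g" encoders.
        return {self.clip_name: getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)}

out = WrapperTokenizer().tokenize_with_weights("a photo of a cat")
print(list(out.keys()))  # ['l']
```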
comfy/sd2_clip.py

@@ -2,7 +2,7 @@ from comfy import sd1_clip
 import torch
 import os
 
-class SD2ClipModel(sd1_clip.SD1ClipModel):
+class SD2ClipHModel(sd1_clip.SDClipModel):
     def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
@@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
 
-class SD2Tokenizer(sd1_clip.SD1Tokenizer):
+class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
+
+class SD2Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
+
+class SD2ClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None):
+        super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel)
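`SD2ClipHModel` keeps the one genuinely SD2-specific detail, `empty_tokens`: an empty prompt is the start token, one end token, then zeros, matching the `pad_with_end=False` tokenizer above. SD1, by contrast, pads with repeated end tokens (`pad_with_end=True` is the `SDTokenizer` default). A small illustration of the difference, assuming SD1's default scheme:

```python
# 49406 = CLIP start-of-text token, 49407 = end-of-text token.
START, END, PAD, LEN = 49406, 49407, 0, 77

sd1_empty = [START] + [END] * (LEN - 1)       # SD1: pad with the end token
sd2_empty = [START, END] + [PAD] * (LEN - 2)  # SD2: one end token, then zeros

assert len(sd1_empty) == len(sd2_empty) == 77
print(sd2_empty[:4])  # [49406, 49407, 0, 0]
```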
comfy/sdxl_clip.py

@@ -2,7 +2,7 @@ from comfy import sd1_clip
 import torch
 import os
 
-class SDXLClipG(sd1_clip.SD1ClipModel):
+class SDXLClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
@@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
     def load_sd(self, sd):
         return super().load_sd(sd)
 
-class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
+class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
 
 
-class SDXLTokenizer(sd1_clip.SD1Tokenizer):
+class SDXLTokenizer:
     def __init__(self, embedding_directory=None):
-        self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
         self.clip_l.layer_norm_hidden_state = False
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
@@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module):
         else:
             return self.clip_l.load_sd(sd)
 
-class SDXLRefinerClipModel(torch.nn.Module):
+class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None):
-        super().__init__()
-        self.clip_g = SDXLClipG(device=device, dtype=dtype)
-
-    def clip_layer(self, layer_idx):
-        self.clip_g.clip_layer(layer_idx)
-
-    def reset_clip_layer(self):
-        self.clip_g.reset_clip_layer()
-
-    def encode_token_weights(self, token_weight_pairs):
-        token_weight_pairs_g = token_weight_pairs["g"]
-        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
-        return g_out, g_pooled
-
-    def load_sd(self, sd):
-        return self.clip_g.load_sd(sd)
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
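With the generic `sd1_clip.SD1ClipModel` wrapper doing the dispatch, the refiner's hand-written forwarding methods collapse into a single `super().__init__` call. A compact sketch of why that is equivalent (a hypothetical stub stands in for `SDXLClipG`, which needs real weights):

```python
# StubClipG is a hypothetical stand-in for SDXLClipG.
class StubClipG:
    def __init__(self, device="cpu", dtype=None):
        pass

    def encode_token_weights(self, pairs):
        return ("g_out", "g_pooled")

class GenericWrapper:  # same shape as the new sd1_clip.SD1ClipModel
    def __init__(self, device="cpu", dtype=None, clip_name="g", clip_model=StubClipG):
        self.clip_name = clip_name
        self.clip = "clip_{}".format(clip_name)
        setattr(self, self.clip, clip_model(device=device, dtype=dtype))

    def encode_token_weights(self, token_weight_pairs):
        # Exactly the selection the deleted refiner method did by hand:
        # token_weight_pairs["g"] -> self.clip_g.encode_token_weights(...)
        return getattr(self, self.clip).encode_token_weights(token_weight_pairs[self.clip_name])

print(GenericWrapper().encode_token_weights({"g": [(49406, 1.0)]}))  # ('g_out', 'g_pooled')
```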
comfy/supported_models.py

@@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
         if ids.dtype == torch.float32:
             state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
+        replace_prefix = {}
+        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"clip_l.": "cond_stage_model."}
+        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
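On load, SD15 now tucks every text-encoder key under a `clip_l.` namespace so the checkpoint lines up with the new wrapper's attribute; on save, the prefix is stripped back out so files on disk keep the original layout. A plausible minimal stand-in for `utils.state_dict_prefix_replace` (an assumption about its behavior, not the actual implementation; the save example also assumes the save hook receives the CLIP model's own keys, which the `"clip_l."` mapping implies):

```python
# Hypothetical minimal equivalent of utils.state_dict_prefix_replace.
def state_dict_prefix_replace(state_dict, replace_prefix):
    out = {}
    for k, v in state_dict.items():
        for old, new in replace_prefix.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out

# Load: checkpoint keys gain the clip_l. namespace.
sd = {"cond_stage_model.transformer.text_model.final_layer_norm.weight": 0}
loaded = state_dict_prefix_replace(sd, {"cond_stage_model.": "cond_stage_model.clip_l."})
print(next(iter(loaded)))
# cond_stage_model.clip_l.transformer.text_model.final_layer_norm.weight

# Save: the model's own "clip_l." keys go back to the checkpoint layout.
model_sd = {"clip_l.transformer.text_model.final_layer_norm.weight": 0}
saved = state_dict_prefix_replace(model_sd, {"clip_l.": "cond_stage_model."})
print(next(iter(saved)))
# cond_stage_model.transformer.text_model.final_layer_norm.weight
```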
@@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
         return model_base.ModelType.EPS
 
     def process_clip_state_dict(self, state_dict):
-        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
+        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
         replace_prefix = {}
-        replace_prefix[""] = "cond_stage_model.model."
+        replace_prefix["clip_h"] = "cond_stage_model.model"
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
         state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
         return state_dict
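SD20 does the same dance under `clip_h.`: on load, `transformers_convert` inserts the namespace while translating the open_clip layout, and on save the `clip_h` prefix maps back to `cond_stage_model.model` before `diffusers_convert.convert_text_enc_state_dict_v20` restores the open_clip key names. Note both sides of the save mapping omit the trailing dot, which still works if the helper replaces raw string prefixes, as in this sketch:

```python
# Same hypothetical prefix-replacement idea as the SD15 sketch above,
# inlined for a single rule.
def prefix_replace(state_dict, old, new):
    return {(new + k[len(old):] if k.startswith(old) else k): v
            for k, v in state_dict.items()}

model_sd = {"clip_h.transformer.text_model.encoder.layers.0.mlp.fc1.weight": 0}
saved = prefix_replace(model_sd, "clip_h", "cond_stage_model.model")
print(next(iter(saved)))
# cond_stage_model.model.transformer.text_model.encoder.layers.0.mlp.fc1.weight
```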