SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.
This commit is contained in:
parent
6ec3f12c6e
commit
e60ca6929a
|
@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):
|
|||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
for b in range(32):
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
|
@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):
|
|||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
|
|
|
@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
|
|||
return z_empty.cpu(), first_pooled.cpu()
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
|
||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
|
@ -342,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
|
@ -454,3 +454,40 @@ class SD1Tokenizer:
|
|||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
|
||||
super().__init__()
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype))
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
getattr(self, self.clip).clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
getattr(self, self.clip).reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out, pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
|
|
@ -2,7 +2,7 @@ from comfy import sd1_clip
|
|||
import torch
|
||||
import os
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
|
@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
|
|||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel)
|
||||
|
|
|
@ -2,7 +2,7 @@ from comfy import sd1_clip
|
|||
import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SD1ClipModel):
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
|
@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
|
|||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
|
||||
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
|
@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
|||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l.layer_norm_hidden_state = False
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
|
@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module):
|
|||
else:
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.clip_g.clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.clip_g.reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
return g_out, g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.clip_g.load_sd(sd)
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
|
|
|
@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
|
|||
if ids.dtype == torch.float32:
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||
|
||||
|
@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
|
|||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix[""] = "cond_stage_model.model."
|
||||
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
|
|
Loading…
Reference in New Issue