Add support for textual inversion embedding for SD1.x CLIP.
This commit is contained in:
parent
702ac43d0c
commit
f73e57d881
|
@ -3,3 +3,4 @@ __pycache__/
|
|||
output/
|
||||
models/checkpoints
|
||||
models/vae
|
||||
models/embeddings
|
||||
|
|
|
@ -66,6 +66,10 @@ Dragging a generated png on the webpage or loading one will give you the full wo
|
|||
|
||||
You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
|
||||
|
||||
To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
|
||||
|
||||
```embedding:embedding_filename.pt```
|
||||
|
||||
### Colab Notebook
|
||||
|
||||
To run it on colab you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
|
||||
|
|
22
comfy/sd.py
22
comfy/sd.py
|
@ -53,19 +53,25 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
|||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, embedding_directory=None):
|
||||
self.target_clip = config["target"]
|
||||
if "params" in config:
|
||||
params = config["params"]
|
||||
else:
|
||||
params = {}
|
||||
|
||||
tokenizer_params = {}
|
||||
|
||||
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
||||
clip = sd2_clip.SD2ClipModel
|
||||
tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
||||
clip = sd1_clip.SD1ClipModel
|
||||
tokenizer = sd1_clip.SD1Tokenizer
|
||||
if "params" in config:
|
||||
self.cond_stage_model = clip(**(config["params"]))
|
||||
else:
|
||||
self.cond_stage_model = clip()
|
||||
self.tokenizer = tokenizer()
|
||||
tokenizer_params['embedding_directory'] = embedding_directory
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
self.tokenizer = tokenizer(**(tokenizer_params))
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.tokenizer.tokenize_with_weights(text)
|
||||
|
@ -103,7 +109,7 @@ class VAE:
|
|||
return samples
|
||||
|
||||
|
||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
|
||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
config = OmegaConf.load(config_path)
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
|
@ -124,7 +130,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
|
|||
load_state_dict_to = [w]
|
||||
|
||||
if output_clip:
|
||||
clip = CLIP(config=clip_config)
|
||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
|
|
|
@ -63,9 +63,38 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||
embedding_weights = []
|
||||
|
||||
for x in tokens:
|
||||
tokens_temp = []
|
||||
for y in x:
|
||||
if isinstance(y, int):
|
||||
tokens_temp += [y]
|
||||
else:
|
||||
embedding_weights += [y]
|
||||
tokens_temp += [next_new_token]
|
||||
next_new_token += 1
|
||||
out_tokens += [tokens_temp]
|
||||
|
||||
if len(embedding_weights) > 0:
|
||||
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
|
||||
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
|
||||
n = token_dict_size
|
||||
for x in embedding_weights:
|
||||
new_embedding.weight[n] = x
|
||||
n += 1
|
||||
self.transformer.set_input_embeddings(new_embedding)
|
||||
return out_tokens
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
|
@ -138,18 +167,49 @@ def unescape_important(text):
|
|||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
embed_path = os.path.join(embedding_directory, embedding_name)
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
valid_file = None
|
||||
for x in extensions:
|
||||
t = embed_path + x
|
||||
if os.path.isfile(t):
|
||||
valid_file = t
|
||||
break
|
||||
if valid_file is None:
|
||||
print("warning, embedding {} does not exist, ignoring".format(embed_path))
|
||||
return None
|
||||
else:
|
||||
embed_path = valid_file
|
||||
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
else:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
if 'string_to_param' in embed:
|
||||
values = embed['string_to_param'].values()
|
||||
else:
|
||||
values = embed.values()
|
||||
return next(iter(values))
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.max_tokens_per_section = self.max_length - 2
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
self.pad_with_end = pad_with_end
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.embedding_directory = embedding_directory
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize_with_weights(self, text):
|
||||
text = escape_important(text)
|
||||
|
@ -157,13 +217,34 @@ class SD1Tokenizer:
|
|||
|
||||
tokens = []
|
||||
for t in parsed_weights:
|
||||
tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
|
||||
to_tokenize = unescape_important(t[0]).split(' ')
|
||||
for word in to_tokenize:
|
||||
temp_tokens = []
|
||||
embedding_identifier = "embedding:"
|
||||
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(embedding_identifier):].strip('\n')
|
||||
embed = load_embed(embedding_name, self.embedding_directory)
|
||||
if embed is not None:
|
||||
if len(embed.shape) == 1:
|
||||
temp_tokens += [(embed, t[1])]
|
||||
else:
|
||||
for x in range(embed.shape[0]):
|
||||
temp_tokens += [(embed[x], t[1])]
|
||||
elif len(word) > 0:
|
||||
tt = self.tokenizer(word)["input_ids"][1:-1]
|
||||
for x in tt:
|
||||
tokens += [(x, t[1])]
|
||||
temp_tokens += [(x, t[1])]
|
||||
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
|
||||
|
||||
#try not to split words in different sections
|
||||
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
|
||||
for x in range(tokens_left):
|
||||
tokens += [(self.end_token, 1.0)]
|
||||
tokens += temp_tokens
|
||||
|
||||
out_tokens = []
|
||||
for x in range(0, len(tokens), self.max_length - 2):
|
||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
|
||||
for x in range(0, len(tokens), self.max_tokens_per_section):
|
||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
|
||||
o_token += [(self.end_token, 1.0)]
|
||||
if self.pad_with_end:
|
||||
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
||||
|
|
3
nodes.py
3
nodes.py
|
@ -127,7 +127,8 @@ class CheckpointLoader:
|
|||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||
config_path = os.path.join(self.config_dir, config_name)
|
||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True)
|
||||
embedding_directory = os.path.join(self.models_dir, "embeddings")
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
|
||||
|
||||
class VAELoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
|
|
Loading…
Reference in New Issue