260 lines
9.5 KiB
Python
260 lines
9.5 KiB
Python
import os
|
|
|
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
|
import torch
|
|
|
|
class ClipTokenWeightEncoder:
|
|
def encode_token_weights(self, token_weight_pairs):
|
|
z_empty = self.encode(self.empty_tokens)
|
|
output = []
|
|
for x in token_weight_pairs:
|
|
tokens = [list(map(lambda a: a[0], x))]
|
|
z = self.encode(tokens)
|
|
for i in range(len(z)):
|
|
for j in range(len(z[i])):
|
|
weight = x[j][1]
|
|
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
|
|
output += [z]
|
|
if (len(output) == 0):
|
|
return self.encode(self.empty_tokens)
|
|
return torch.cat(output, dim=-2)
|
|
|
|
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
|
LAYERS = [
|
|
"last",
|
|
"pooled",
|
|
"hidden"
|
|
]
|
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
|
|
super().__init__()
|
|
assert layer in self.LAYERS
|
|
if textmodel_path is not None:
|
|
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
|
|
else:
|
|
if textmodel_json_config is None:
|
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
|
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
|
self.transformer = CLIPTextModel(config)
|
|
|
|
self.device = device
|
|
self.max_length = max_length
|
|
if freeze:
|
|
self.freeze()
|
|
self.layer = layer
|
|
self.layer_idx = None
|
|
self.empty_tokens = [[49406] + [49407] * 76]
|
|
if layer == "hidden":
|
|
assert layer_idx is not None
|
|
assert abs(layer_idx) <= 12
|
|
self.clip_layer(layer_idx)
|
|
|
|
def freeze(self):
|
|
self.transformer = self.transformer.eval()
|
|
#self.train = disabled_train
|
|
for param in self.parameters():
|
|
param.requires_grad = False
|
|
|
|
def clip_layer(self, layer_idx):
|
|
if abs(layer_idx) >= 12:
|
|
self.layer = "last"
|
|
else:
|
|
self.layer = "hidden"
|
|
self.layer_idx = layer_idx
|
|
|
|
def set_up_textual_embeddings(self, tokens, current_embeds):
|
|
out_tokens = []
|
|
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
|
embedding_weights = []
|
|
|
|
for x in tokens:
|
|
tokens_temp = []
|
|
for y in x:
|
|
if isinstance(y, int):
|
|
tokens_temp += [y]
|
|
else:
|
|
embedding_weights += [y]
|
|
tokens_temp += [next_new_token]
|
|
next_new_token += 1
|
|
out_tokens += [tokens_temp]
|
|
|
|
if len(embedding_weights) > 0:
|
|
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
|
|
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
|
|
n = token_dict_size
|
|
for x in embedding_weights:
|
|
new_embedding.weight[n] = x
|
|
n += 1
|
|
self.transformer.set_input_embeddings(new_embedding)
|
|
return out_tokens
|
|
|
|
def forward(self, tokens):
|
|
backup_embeds = self.transformer.get_input_embeddings()
|
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
|
tokens = torch.LongTensor(tokens).to(self.device)
|
|
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
|
self.transformer.set_input_embeddings(backup_embeds)
|
|
|
|
if self.layer == "last":
|
|
z = outputs.last_hidden_state
|
|
elif self.layer == "pooled":
|
|
z = outputs.pooler_output[:, None, :]
|
|
else:
|
|
z = outputs.hidden_states[self.layer_idx]
|
|
z = self.transformer.text_model.final_layer_norm(z)
|
|
|
|
return z
|
|
|
|
def encode(self, tokens):
|
|
return self(tokens)
|
|
|
|
def parse_parentheses(string):
|
|
result = []
|
|
current_item = ""
|
|
nesting_level = 0
|
|
for char in string:
|
|
if char == "(":
|
|
if nesting_level == 0:
|
|
if current_item:
|
|
result.append(current_item)
|
|
current_item = "("
|
|
else:
|
|
current_item = "("
|
|
else:
|
|
current_item += char
|
|
nesting_level += 1
|
|
elif char == ")":
|
|
nesting_level -= 1
|
|
if nesting_level == 0:
|
|
result.append(current_item + ")")
|
|
current_item = ""
|
|
else:
|
|
current_item += char
|
|
else:
|
|
current_item += char
|
|
if current_item:
|
|
result.append(current_item)
|
|
return result
|
|
|
|
def token_weights(string, current_weight):
|
|
a = parse_parentheses(string)
|
|
out = []
|
|
for x in a:
|
|
weight = current_weight
|
|
if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
|
|
x = x[1:-1]
|
|
xx = x.rfind(":")
|
|
weight *= 1.1
|
|
if xx > 0:
|
|
try:
|
|
weight = float(x[xx+1:])
|
|
x = x[:xx]
|
|
except:
|
|
pass
|
|
out += token_weights(x, weight)
|
|
else:
|
|
out += [(x, current_weight)]
|
|
return out
|
|
|
|
def escape_important(text):
|
|
text = text.replace("\\)", "\0\1")
|
|
text = text.replace("\\(", "\0\2")
|
|
return text
|
|
|
|
def unescape_important(text):
|
|
text = text.replace("\0\1", ")")
|
|
text = text.replace("\0\2", "(")
|
|
return text
|
|
|
|
def load_embed(embedding_name, embedding_directory):
|
|
embed_path = os.path.join(embedding_directory, embedding_name)
|
|
if not os.path.isfile(embed_path):
|
|
extensions = ['.safetensors', '.pt', '.bin']
|
|
valid_file = None
|
|
for x in extensions:
|
|
t = embed_path + x
|
|
if os.path.isfile(t):
|
|
valid_file = t
|
|
break
|
|
if valid_file is None:
|
|
print("warning, embedding {} does not exist, ignoring".format(embed_path))
|
|
return None
|
|
else:
|
|
embed_path = valid_file
|
|
|
|
if embed_path.lower().endswith(".safetensors"):
|
|
import safetensors.torch
|
|
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
|
else:
|
|
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
|
if 'string_to_param' in embed:
|
|
values = embed['string_to_param'].values()
|
|
else:
|
|
values = embed.values()
|
|
return next(iter(values))
|
|
|
|
class SD1Tokenizer:
|
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
|
if tokenizer_path is None:
|
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
|
self.max_length = max_length
|
|
self.max_tokens_per_section = self.max_length - 2
|
|
|
|
empty = self.tokenizer('')["input_ids"]
|
|
self.start_token = empty[0]
|
|
self.end_token = empty[1]
|
|
self.pad_with_end = pad_with_end
|
|
vocab = self.tokenizer.get_vocab()
|
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
|
self.embedding_directory = embedding_directory
|
|
self.max_word_length = 8
|
|
|
|
def tokenize_with_weights(self, text):
|
|
text = escape_important(text)
|
|
parsed_weights = token_weights(text, 1.0)
|
|
|
|
tokens = []
|
|
for t in parsed_weights:
|
|
to_tokenize = unescape_important(t[0]).split(' ')
|
|
for word in to_tokenize:
|
|
temp_tokens = []
|
|
embedding_identifier = "embedding:"
|
|
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
|
|
embedding_name = word[len(embedding_identifier):].strip('\n')
|
|
embed = load_embed(embedding_name, self.embedding_directory)
|
|
if embed is not None:
|
|
if len(embed.shape) == 1:
|
|
temp_tokens += [(embed, t[1])]
|
|
else:
|
|
for x in range(embed.shape[0]):
|
|
temp_tokens += [(embed[x], t[1])]
|
|
elif len(word) > 0:
|
|
tt = self.tokenizer(word)["input_ids"][1:-1]
|
|
for x in tt:
|
|
temp_tokens += [(x, t[1])]
|
|
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
|
|
|
|
#try not to split words in different sections
|
|
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
|
|
for x in range(tokens_left):
|
|
tokens += [(self.end_token, 1.0)]
|
|
tokens += temp_tokens
|
|
|
|
out_tokens = []
|
|
for x in range(0, len(tokens), self.max_tokens_per_section):
|
|
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
|
|
o_token += [(self.end_token, 1.0)]
|
|
if self.pad_with_end:
|
|
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
|
else:
|
|
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
|
|
|
|
out_tokens += [o_token]
|
|
|
|
return out_tokens
|
|
|
|
def untokenize(self, token_weight_pair):
|
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|