This commit is contained in:
comfyanonymous 2023-04-15 14:16:50 -04:00
commit 719c26c3c9
2 changed files with 88 additions and 45 deletions

View File

@ -372,10 +372,16 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def encode(self, text):
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode(self, text, from_tokens=False):
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
tokens = self.tokenizer.tokenize_with_weights(text)
if from_tokens:
tokens = text
else:
tokens = self.tokenizer.tokenize_with_weights(text)
try:
self.patcher.patch_model()
cond = self.cond_stage_model.encode_token_weights(tokens)

View File

@ -260,60 +260,97 @@ class SD1Tokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
self.max_word_length = 8
self.embedding_identifier = "embedding:"
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
'''
embed = load_embed(embedding_name, self.embedding_directory)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory)
return (embed, embedding_name[len(stripped):])
return (embed, "")
def tokenize_with_weights(self, text:str, return_word_ids=False):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors.
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
Returned list has the dimensions NxM where M is the input size of CLIP
'''
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
def tokenize_with_weights(self, text):
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
#tokenize words
tokens = []
for t in parsed_weights:
to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
while len(to_tokenize) > 0:
word = to_tokenize.pop(0)
temp_tokens = []
embedding_identifier = "embedding:"
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(embedding_identifier):].strip('\n')
embed = load_embed(embedding_name, self.embedding_directory)
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory)
if embed is not None:
to_tokenize.insert(0, embedding_name[len(stripped):])
if embed is not None:
if len(embed.shape) == 1:
temp_tokens += [(embed, t[1])]
else:
for x in range(embed.shape[0]):
temp_tokens += [(embed[x], t[1])]
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
elif len(word) > 0:
tt = self.tokenizer(word)["input_ids"][1:-1]
for x in tt:
temp_tokens += [(x, t[1])]
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "":
word = leftover
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
#reshape token array to CLIP input size
batched_tokens = []
batch = [(self.start_token, 1.0, 0)]
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
#try not to split words in different sections
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
for x in range(tokens_left):
tokens += [(self.end_token, 1.0)]
tokens += temp_tokens
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = [(self.start_token, 1.0, 0)]
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
#fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
out_tokens = []
for x in range(0, len(tokens), self.max_tokens_per_section):
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
o_token += [(self.end_token, 1.0)]
if self.pad_with_end:
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
else:
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
out_tokens += [o_token]
return batched_tokens
return out_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))