Fix bundled embed.

This commit is contained in:
comfyanonymous 2024-08-07 13:30:45 -04:00
parent 17030fd4c0
commit e1c528196e
1 changed files with 12 additions and 15 deletions

View File

@ -313,17 +313,14 @@ def expand_directory_list(directories):
dirs.add(root) dirs.add(root)
return list(dirs) return list(dirs)
def bundled_embed(embed, key): #bundled embedding in lora format def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
i = 0 i = 0
out_list = [] out_list = []
while True: for k in embed:
i += 1 if k.startswith(prefix) and k.endswith(suffix):
k = key.format(i) out_list.append(embed[k])
w = embed.get(k, None) if len(out_list) == 0:
if w is None: return None
break
else:
out_list.append(w)
return torch.cat(out_list, dim=0) return torch.cat(out_list, dim=0)
@ -392,11 +389,11 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = torch.cat(out_list, dim=0) embed_out = torch.cat(out_list, dim=0)
elif embed_key is not None and embed_key in embed: elif embed_key is not None and embed_key in embed:
embed_out = embed[embed_key] embed_out = embed[embed_key]
elif 'bundle_emb.place1.string_to_param.*' in embed:
embed_out = bundled_embed(embed, 'bundle_emb.place{}.string_to_param.*')
elif 'bundle_emb.place1.{}'.format(embed_key) in embed:
embed_out = bundled_embed(embed, 'bundle_emb.place{}.{}'.format('{}', embed_key))
else: else:
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
if embed_out is None:
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
if embed_out is None:
values = embed.values() values = embed.values()
embed_out = next(iter(values)) embed_out = next(iter(values))
return embed_out return embed_out