Fix bundled embed.
This commit is contained in:
parent
17030fd4c0
commit
e1c528196e
|
@ -313,17 +313,14 @@ def expand_directory_list(directories):
|
|||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
def bundled_embed(embed, key): #bundled embedding in lora format
|
||||
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
|
||||
i = 0
|
||||
out_list = []
|
||||
while True:
|
||||
i += 1
|
||||
k = key.format(i)
|
||||
w = embed.get(k, None)
|
||||
if w is None:
|
||||
break
|
||||
else:
|
||||
out_list.append(w)
|
||||
for k in embed:
|
||||
if k.startswith(prefix) and k.endswith(suffix):
|
||||
out_list.append(embed[k])
|
||||
if len(out_list) == 0:
|
||||
return None
|
||||
|
||||
return torch.cat(out_list, dim=0)
|
||||
|
||||
|
@ -392,11 +389,11 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||
embed_out = torch.cat(out_list, dim=0)
|
||||
elif embed_key is not None and embed_key in embed:
|
||||
embed_out = embed[embed_key]
|
||||
elif 'bundle_emb.place1.string_to_param.*' in embed:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.place{}.string_to_param.*')
|
||||
elif 'bundle_emb.place1.{}'.format(embed_key) in embed:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.place{}.{}'.format('{}', embed_key))
|
||||
else:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
|
||||
if embed_out is None:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
|
||||
if embed_out is None:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
|
Loading…
Reference in New Issue