Safely load pickled embeds that don't load with weights_only=True.
This commit is contained in:
parent
334aab05e5
commit
04d9bc13af
|
@ -3,6 +3,7 @@ import os
|
|||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
||||
import torch
|
||||
import traceback
|
||||
import zipfile
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
|
@ -171,6 +172,26 @@ def unescape_important(text):
|
|||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
def safe_load_embed_zip(embed_path):
|
||||
with zipfile.ZipFile(embed_path) as myzip:
|
||||
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
|
||||
names.reverse()
|
||||
for n in names:
|
||||
with myzip.open(n) as myfile:
|
||||
data = myfile.read()
|
||||
number = len(data) // 4
|
||||
length_embed = 1024 #sd2.x
|
||||
if number < 768:
|
||||
continue
|
||||
if number % 768 == 0:
|
||||
length_embed = 768 #sd1.x
|
||||
num_embeds = number // length_embed
|
||||
embed = torch.frombuffer(data, dtype=torch.float)
|
||||
out = embed.reshape((num_embeds, length_embed)).clone()
|
||||
del embed
|
||||
return out
|
||||
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory):
|
|||
|
||||
embed_path = valid_file
|
||||
|
||||
embed_out = None
|
||||
|
||||
try:
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
else:
|
||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
try:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
except:
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
except Exception as e:
|
||||
|
@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory):
|
|||
print("error loading embedding, skipping loading:", embedding_name)
|
||||
return None
|
||||
|
||||
if 'string_to_param' in embed:
|
||||
values = embed['string_to_param'].values()
|
||||
else:
|
||||
values = embed.values()
|
||||
return next(iter(values))
|
||||
if embed_out is None:
|
||||
if 'string_to_param' in embed:
|
||||
values = embed['string_to_param'].values()
|
||||
else:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
||||
|
|
Loading…
Reference in New Issue