Safely load pickled embeds that don't load with weights_only=True.
This commit is contained in:
parent
334aab05e5
commit
04d9bc13af
|
@ -3,6 +3,7 @@ import os
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
|
import zipfile
|
||||||
|
|
||||||
class ClipTokenWeightEncoder:
|
class ClipTokenWeightEncoder:
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
@ -171,6 +172,26 @@ def unescape_important(text):
|
||||||
text = text.replace("\0\2", "(")
|
text = text.replace("\0\2", "(")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def safe_load_embed_zip(embed_path):
|
||||||
|
with zipfile.ZipFile(embed_path) as myzip:
|
||||||
|
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
|
||||||
|
names.reverse()
|
||||||
|
for n in names:
|
||||||
|
with myzip.open(n) as myfile:
|
||||||
|
data = myfile.read()
|
||||||
|
number = len(data) // 4
|
||||||
|
length_embed = 1024 #sd2.x
|
||||||
|
if number < 768:
|
||||||
|
continue
|
||||||
|
if number % 768 == 0:
|
||||||
|
length_embed = 768 #sd1.x
|
||||||
|
num_embeds = number // length_embed
|
||||||
|
embed = torch.frombuffer(data, dtype=torch.float)
|
||||||
|
out = embed.reshape((num_embeds, length_embed)).clone()
|
||||||
|
del embed
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def load_embed(embedding_name, embedding_directory):
|
def load_embed(embedding_name, embedding_directory):
|
||||||
if isinstance(embedding_directory, str):
|
if isinstance(embedding_directory, str):
|
||||||
embedding_directory = [embedding_directory]
|
embedding_directory = [embedding_directory]
|
||||||
|
@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory):
|
||||||
|
|
||||||
embed_path = valid_file
|
embed_path = valid_file
|
||||||
|
|
||||||
|
embed_out = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if embed_path.lower().endswith(".safetensors"):
|
if embed_path.lower().endswith(".safetensors"):
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||||
else:
|
else:
|
||||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
if 'weights_only' in torch.load.__code__.co_varnames:
|
||||||
|
try:
|
||||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||||
|
except:
|
||||||
|
embed_out = safe_load_embed_zip(embed_path)
|
||||||
else:
|
else:
|
||||||
embed = torch.load(embed_path, map_location="cpu")
|
embed = torch.load(embed_path, map_location="cpu")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory):
|
||||||
print("error loading embedding, skipping loading:", embedding_name)
|
print("error loading embedding, skipping loading:", embedding_name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if embed_out is None:
|
||||||
if 'string_to_param' in embed:
|
if 'string_to_param' in embed:
|
||||||
values = embed['string_to_param'].values()
|
values = embed['string_to_param'].values()
|
||||||
else:
|
else:
|
||||||
values = embed.values()
|
values = embed.values()
|
||||||
return next(iter(values))
|
embed_out = next(iter(values))
|
||||||
|
return embed_out
|
||||||
|
|
||||||
class SD1Tokenizer:
|
class SD1Tokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
||||||
|
|
Loading…
Reference in New Issue