Safely load pickled embeds that don't load with weights_only=True.

This commit is contained in:
comfyanonymous 2023-04-14 15:33:43 -04:00
parent 334aab05e5
commit 04d9bc13af
1 changed file with 34 additions and 6 deletions
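Background for this change: newer PyTorch releases accept a weights_only=True argument to torch.load that restricts unpickling to plain tensor data, so embeddings saved with arbitrary pickled objects are refused outright. Rather than falling back to a full (unsafe) unpickle, this commit reads the raw tensor storages straight out of the checkpoint's zip container. A minimal sketch of the failure mode being worked around (the file path is an assumption, not from the commit):

import torch

try:
    emb = torch.load("embeddings/legacy.pt", weights_only=True, map_location="cpu")
except Exception as err:
    # Legacy pickled embeds land here; the safe_load_embed_zip fallback
    # added below recovers the tensor without unpickling anything.
    print("weights_only load refused:", err)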


@@ -3,6 +3,7 @@ import os
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
 import torch
 import traceback
+import zipfile

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -171,6 +172,26 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text

+def safe_load_embed_zip(embed_path):
+    with zipfile.ZipFile(embed_path) as myzip:
+        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
+        names.reverse()
+        for n in names:
+            with myzip.open(n) as myfile:
+                data = myfile.read()
+                number = len(data) // 4
+                length_embed = 1024 #sd2.x
+                if number < 768:
+                    continue
+                if number % 768 == 0:
+                    length_embed = 768 #sd1.x
+                num_embeds = number // length_embed
+                embed = torch.frombuffer(data, dtype=torch.float)
+                out = embed.reshape((num_embeds, length_embed)).clone()
+                del embed
+                return out
+
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
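How the fallback works: a file written by torch.save is a zip archive in which each tensor's storage sits as raw bytes under "<archive>/data/<storage_key>". safe_load_embed_zip scans those entries in reverse name order, interprets each as float32 (hence len(data) // 4 elements), skips anything too small to be an embedding table, picks the token width (768 for SD1.x when the element count divides evenly, otherwise 1024 for SD2.x), and reshapes the first plausible storage into a (num_embeds, length_embed) tensor. The .clone() copies the data out of the non-writable buffer that torch.frombuffer wraps. A small sketch of the archive layout this relies on ("embedding.pt" is an assumed file name):

import zipfile

with zipfile.ZipFile("embedding.pt") as z:
    for name in z.namelist():
        if "data/" in name:
            # Each entry is one tensor storage: raw little-endian bytes.
            print(name, z.getinfo(name).file_size, "bytes")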
@@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory):

         embed_path = valid_file

+    embed_out = None
+
     try:
         if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(embed_path, device="cpu")
         else:
             if 'weights_only' in torch.load.__code__.co_varnames:
-                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                try:
+                    embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                except:
+                    embed_out = safe_load_embed_zip(embed_path)
             else:
                 embed = torch.load(embed_path, map_location="cpu")
     except Exception as e:
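The 'weights_only' in torch.load.__code__.co_varnames test is a version probe: it checks whether the installed torch.load even accepts the weights_only keyword (PyTorch 1.13 and later) before attempting the safe path, and only falls back to the zip reader when that safe load raises. An equivalent probe written with inspect, shown as a sketch rather than the commit's code:

import inspect
import torch

# True on PyTorch builds whose torch.load accepts weights_only.
supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters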
@@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory):
         print("error loading embedding, skipping loading:", embedding_name)
         return None

-    if 'string_to_param' in embed:
-        values = embed['string_to_param'].values()
-    else:
-        values = embed.values()
-    return next(iter(values))
+    if embed_out is None:
+        if 'string_to_param' in embed:
+            values = embed['string_to_param'].values()
+        else:
+            values = embed.values()
+        embed_out = next(iter(values))
+    return embed_out

 class SD1Tokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
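Hypothetical standalone use of the new helper (the path is an assumption); it returns None implicitly when no storage entry in the archive is large enough to be an embedding table:

out = safe_load_embed_zip("embeddings/legacy.pt")
if out is not None:
    print(out.shape)  # (num_embeds, 768) for SD1.x, (num_embeds, 1024) for SD2.x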