Turn on safe load for a few models.

comfyanonymous 2023-06-13 10:12:03 -04:00
parent 735ac4cf81
commit ff9b22d79e
1 changed file with 5 additions and 5 deletions


@@ -85,7 +85,7 @@ LORA_UNET_MAP_RESNET = {
 }
 
 def load_lora(path, to_load):
-    lora = utils.load_torch_file(path)
+    lora = utils.load_torch_file(path, safe_load=True)
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -722,7 +722,7 @@ class ControlNet:
         return out
 
 def load_controlnet(ckpt_path, model=None):
-    controlnet_data = utils.load_torch_file(ckpt_path)
+    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
     pth = False
     sd2 = False
@@ -924,7 +924,7 @@ class StyleModel:
 def load_style_model(ckpt_path):
-    model_data = utils.load_torch_file(ckpt_path)
+    model_data = utils.load_torch_file(ckpt_path, safe_load=True)
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@@ -935,7 +935,7 @@ def load_style_model(ckpt_path):
 def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path)
+    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
         config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@@ -946,7 +946,7 @@ def load_clip(ckpt_path, embedding_directory=None):
     return clip
 
 def load_gligen(ckpt_path):
-    data = utils.load_torch_file(ckpt_path)
+    data = utils.load_torch_file(ckpt_path, safe_load=True)
    model = gligen.load_gligen(data)
    if model_management.should_use_fp16():
        model = model.half()
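
For context, the safe_load flag guards against malicious pickled checkpoints: instead of letting torch.load unpickle arbitrary objects (whose deserialization hooks can execute code), the loader restricts what gets deserialized. Below is a minimal sketch of what a safe_load-aware loader might look like; it assumes torch and safetensors are installed, and the body is illustrative only, not the project's actual utils.load_torch_file implementation.

import torch
import safetensors.torch

def load_torch_file(ckpt_path, safe_load=False):
    # Illustrative sketch of a safe_load-aware checkpoint loader (assumed, not the real utils code).
    if ckpt_path.lower().endswith(".safetensors"):
        # safetensors files store raw tensors with no pickled code, so they are safe by construction.
        return safetensors.torch.load_file(ckpt_path, device="cpu")
    if safe_load:
        # weights_only=True restricts unpickling to tensors and basic containers,
        # refusing objects whose reduce hooks could run arbitrary code on load.
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    else:
        sd = torch.load(ckpt_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" key.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    return sd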