Fix no longer working on old pytorch.

This commit is contained in:
comfyanonymous 2024-08-01 22:19:53 -04:00
parent ce9ac2fe05
commit 369f459b20
1 changed files with 7 additions and 2 deletions

View File

@ -826,9 +826,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype): def load_unet(self, unet_name, weight_dtype):
weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype] dtype = None
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=weight_dtype) model = comfy.sd.load_unet(unet_path, dtype=dtype)
return (model,) return (model,)
class CLIPLoader: class CLIPLoader: