diff --git a/nodes.py b/nodes.py index 1994f119..e296597c 100644 --- a/nodes.py +++ b/nodes.py @@ -826,9 +826,14 @@ class UNETLoader: CATEGORY = "advanced/loaders" def load_unet(self, unet_name, weight_dtype): - weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype] + dtype = None + if weight_dtype == "fp8_e4m3fn": + dtype = torch.float8_e4m3fn + elif weight_dtype == "fp8_e5m2": + dtype = torch.float8_e5m2 + unet_path = folder_paths.get_full_path("unet", unet_name) - model = comfy.sd.load_unet(unet_path, dtype=weight_dtype) + model = comfy.sd.load_unet(unet_path, dtype=dtype) return (model,) class CLIPLoader: