From e38c94228bce913c1a88f6776f6a21bd64926aec Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 9 Oct 2024 19:43:17 -0400
Subject: [PATCH] Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model
 Loader node.

This is used to load weights in fp8 and use fp8 matrix multiplication.
---
 comfy/model_base.py            |  2 +-
 comfy/model_management.py      | 13 ++++++++++++-
 comfy/ops.py                   |  6 +++++-
 comfy/sd.py                    |  5 ++++-
 comfy/supported_models_base.py |  1 +
 nodes.py                       |  5 ++++-
 6 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9bfdb3b3..a98fee1d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a97d489d..09798bd0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 
 try:
-    logging.info("pytorch version: {}".format(torch.version.__version__))
+    logging.info("pytorch version: {}".format(torch_version))
 except:
     pass
 
@@ -1065,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False
 
 def supports_fp8_compute(device=None):
+    if not is_nvidia():
+        return False
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 9:
         return True
@@ -1072,6 +1075,14 @@
         return False
     if props.minor < 9:
         return False
+
+    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+        return False
+
+    if WINDOWS:
+        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+            return False
+
     return True
 
 def soft_empty_cache(force=False):
diff --git a/comfy/ops.py b/comfy/ops.py
index f9411ba5..c90e25ea 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -299,7 +299,11 @@ class fp8_ops(manual_cast):
             return torch.nn.functional.linear(input, weight, bias)
 
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
+    if comfy.model_management.supports_fp8_compute(load_device):
+        if (fp8_optimizations or args.fast) and not disable_fast_fp8:
+            return fp8_ops
+
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
     if args.fast and not disable_fast_fp8:
diff --git a/comfy/sd.py b/comfy/sd.py
index a494e531..feb1138d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -433,6 +433,7 @@ def detect_te_model(sd):
 
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts
+
     class EmptyClass:
         pass
 
@@ -592,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
 
     if output_model:
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
-        offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
         model.load_model_weights(sd, diffusion_model_prefix)
 
@@ -678,6 +678,9 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+    if model_options.get("fp8_optimizations", False):
+        model_config.optimizations["fp8"] = True
+
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 7a2152f9..68e89551 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -49,6 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
     custom_operations = None
+    optimizations = {"fp8": False}
 
     @classmethod
     def matches(s, unet_config, state_dict=None):
diff --git a/nodes.py b/nodes.py
index a4065c76..15a78352 100644
--- a/nodes.py
+++ b/nodes.py
@@ -861,7 +861,7 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
-                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
+                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -872,6 +872,9 @@ class UNETLoader:
         model_options = {}
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn
+        elif weight_dtype == "fp8_e4m3fn_fast":
+            model_options["dtype"] = torch.float8_e4m3fn
+            model_options["fp8_optimizations"] = True
         elif weight_dtype == "fp8_e5m2":
             model_options["dtype"] = torch.float8_e5m2
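
Note (not part of the patch): a minimal sketch of what selecting "fp8_e4m3fn_fast" in the
UNETLoader node amounts to, based only on the changes above; all names mirror the patched
code, and the comments describe the code paths this patch adds.

    import torch

    # Picking "fp8_e4m3fn_fast" in the node produces these model_options;
    # load_diffusion_model_state_dict() then sets model_config.optimizations["fp8"] = True.
    model_options = {
        "dtype": torch.float8_e4m3fn,   # store the diffusion model weights in fp8 e4m3fn
        "fp8_optimizations": True,      # ask pick_operations() for the fp8_ops path
    }

    # pick_operations() only returns fp8_ops (fp8 matrix multiplication) when
    # model_management.supports_fp8_compute() is True: an NVIDIA GPU with compute
    # capability >= 8.9 (or >= 9.0) and PyTorch >= 2.3 (>= 2.4 on Windows);
    # otherwise loading falls back to the existing weight-casting operations.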