Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.

This is used to load weights in fp8 and use fp8 matrix multiplication.
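Under the hood the fast path relies on PyTorch's scaled fp8 GEMM, torch._scaled_mm, which the fp8 operations dispatch to. A minimal standalone sketch of that primitive, not ComfyUI's exact code; it assumes PyTorch >= 2.4 on an Ada (sm89) or Hopper (sm90) GPU, since the _scaled_mm signature has changed between releases:

import torch

# fp8 GEMM needs fp8 operands; dimensions must be multiples of 16
x = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)  # (out, in), like a Linear weight
one = torch.ones((), device="cuda")  # per-tensor scales; real code derives these from the data

# the second operand must be column-major, hence the transpose
out = torch._scaled_mm(x, w.t(), scale_a=one, scale_b=one, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 32])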
comfyanonymous 2024-10-09 19:43:17 -04:00
parent 203942c8b2
commit e38c94228b
6 changed files with 27 additions and 5 deletions

comfy/model_base.py

@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

comfy/model_management.py

@@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 try:
-    logging.info("pytorch version: {}".format(torch.version.__version__))
+    logging.info("pytorch version: {}".format(torch_version))
 except:
     pass
@@ -1065,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         return False
 
 def supports_fp8_compute(device=None):
+    if not is_nvidia():
+        return False
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 9:
         return True
@@ -1072,6 +1075,14 @@ def supports_fp8_compute(device=None):
         return False
     if props.minor < 9:
         return False
+
+    if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+        return False
+
+    if WINDOWS:
+        if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+            return False
+
     return True
 
 def soft_empty_cache(force=False):
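Note how the new gate reads the version: torch_version is the raw version string, so torch_version[0] and torch_version[2] are the major and minor digits (the indexing assumes single-digit components). The effect is that fp8 compute requires PyTorch 2.3+, and 2.4+ on Windows. A quick illustration with a hypothetical version string:

torch_version = "2.4.1"  # the shape of torch.version.__version__
major, minor = int(torch_version[0]), int(torch_version[2])
# mirrors the checks above: reject < 2.3 everywhere, < 2.4 on Windows
print((major, minor) >= (2, 3), (major, minor) >= (2, 4))  # True True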

comfy/ops.py

@@ -299,7 +299,11 @@ class fp8_ops(manual_cast):
             return torch.nn.functional.linear(input, weight, bias)
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
+    if comfy.model_management.supports_fp8_compute(load_device):
+        if (fp8_optimizations or args.fast) and not disable_fast_fp8:
+            return fp8_ops
+
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
 
     if args.fast and not disable_fast_fp8:
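With this change the fp8 path can be enabled per model rather than only through the global --fast flag: when the load device supports fp8 compute and either fp8_optimizations or args.fast is set, pick_operations returns fp8_ops before the usual dtype comparison runs. A quick way to see whether a machine qualifies, using only functions touched by this commit (a sketch, run from a ComfyUI checkout):

import comfy.model_management as mm

device = mm.get_torch_device()
# True only on NVIDIA sm89/sm90 with a new-enough PyTorch, per the checks above
print("fp8 compute supported:", mm.supports_fp8_compute(device))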

comfy/sd.py

@@ -433,6 +433,7 @@ def detect_te_model(sd):
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts
+
     class EmptyClass:
         pass
@@ -592,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_model:
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
-        offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
         model.load_model_weights(sd, diffusion_model_prefix)
@@ -678,6 +678,9 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+    if model_options.get("fp8_optimizations", False):
+        model_config.optimizations["fp8"] = True
+
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")

comfy/supported_models_base.py

@@ -49,6 +49,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
+    optimizations = {"fp8": False}
 
     @classmethod
     def matches(s, unet_config, state_dict=None):

nodes.py

@@ -861,7 +861,7 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
-                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
+                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
                             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -872,6 +872,9 @@ class UNETLoader:
         model_options = {}
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn
+        elif weight_dtype == "fp8_e4m3fn_fast":
+            model_options["dtype"] = torch.float8_e4m3fn
+            model_options["fp8_optimizations"] = True
         elif weight_dtype == "fp8_e5m2":
             model_options["dtype"] = torch.float8_e5m2