Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
This commit is contained in:
parent
203942c8b2
commit
e38c94228b
|
@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
|
||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
|
|
@ -145,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
logging.info("pytorch version: {}".format(torch_version))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1065,6 +1065,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_fp8_compute(device=None):
|
def supports_fp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties(device)
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 9:
|
if props.major >= 9:
|
||||||
return True
|
return True
|
||||||
|
@ -1072,6 +1075,14 @@ def supports_fp8_compute(device=None):
|
||||||
return False
|
return False
|
||||||
if props.minor < 9:
|
if props.minor < 9:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if WINDOWS:
|
||||||
|
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
|
|
@ -299,7 +299,11 @@ class fp8_ops(manual_cast):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
|
||||||
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
|
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
|
||||||
|
return fp8_ops
|
||||||
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
if args.fast and not disable_fast_fp8:
|
if args.fast and not disable_fast_fp8:
|
||||||
|
|
|
@ -433,6 +433,7 @@ def detect_te_model(sd):
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -592,7 +593,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model.load_model_weights(sd, diffusion_model_prefix)
|
model.load_model_weights(sd, diffusion_model_prefix)
|
||||||
|
|
||||||
|
@ -678,6 +678,9 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
|
if model_options.get("fp8_optimizations", False):
|
||||||
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
|
|
@ -49,6 +49,7 @@ class BASE:
|
||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def matches(s, unet_config, state_dict=None):
|
||||||
|
|
5
nodes.py
5
nodes.py
|
@ -861,7 +861,7 @@ class UNETLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
|
@ -872,6 +872,9 @@ class UNETLoader:
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if weight_dtype == "fp8_e4m3fn":
|
if weight_dtype == "fp8_e4m3fn":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||||
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
|
model_options["fp8_optimizations"] = True
|
||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue