Mixed precision diffusion models with scaled fp8.
This change adds support for diffusion models where all the linears are scaled fp8 while the other weights stay in their original precision.
parent 83ca891118
commit 0075c6d096
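In short: `scaled_fp8` on the model config changes from a bool to the fp8 dtype of the scaled linears (or `None` when the model is not scaled fp8), and checkpoints record that dtype in an empty `scaled_fp8` marker tensor. Below is a minimal sketch of that save/detect convention with hypothetical standalone helper names; the real logic lives in the state-dict saving and model detection changes in the diff that follows.

import torch

def add_scaled_fp8_marker(unet_state_dict, scaled_fp8_dtype):
    # Hypothetical helper: store an empty tensor whose dtype records the fp8
    # format of the scaled linears (the actual change writes this key while
    # saving the diffusion model's state dict).
    unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=scaled_fp8_dtype)
    return unet_state_dict

def detect_scaled_fp8(state_dict, prefix=""):
    # Hypothetical helper mirroring the detection logic: return the fp8 dtype
    # to use for the scaled linears, or None if the checkpoint is not scaled fp8.
    marker = state_dict.get("{}scaled_fp8".format(prefix), None)
    if marker is None:
        return None
    dtype = marker.dtype
    if dtype == torch.float32:
        # Markers written by the previous code were plain float32 tensors;
        # treat those checkpoints as e4m3fn.
        dtype = torch.float8_e4m3fn
    return dtype

sd = add_scaled_fp8_marker({}, torch.float8_e4m3fn)
print(detect_scaled_fp8(sd))  # torch.float8_e4m3fn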
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
@@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module):
         unet_state_dict = self.diffusion_model.state_dict()

-        if self.model_config.scaled_fp8:
-            unet_state_dict["scaled_fp8"] = torch.tensor([])
+        if self.model_config.scaled_fp8 is not None:
+            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
@@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)

-    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
-        model_config.scaled_fp8 = True
+    scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
+    if scaled_fp8_weight is not None:
+        model_config.scaled_fp8 = scaled_fp8_weight.dtype
+        if model_config.scaled_fp8 == torch.float32:
+            model_config.scaled_fp8 = torch.float8_e4m3fn

     return model_config
@@ -334,10 +334,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
     return scaled_fp8_op

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
-    if scaled_fp8:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)

     if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
         return fp8_ops
@@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         return None

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)

     model_config.custom_operations = model_options.get("custom_operations", None)
@@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)

     if dtype is None:
@@ -49,7 +49,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
-    scaled_fp8 = False
+    scaled_fp8 = None
     optimizations = {"fp8": False}

     @classmethod