From 61196d88576c95c1cd8535e881af48172d5af525 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 25 Nov 2024 05:00:23 -0500
Subject: [PATCH] Add option to run inference on the diffusion model in fp32
 and fp64.

---
 comfy/cli_args.py         | 6 ++++--
 comfy/model_management.py | 6 +++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 20b9f474..847f35ab 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpunet_group = parser.add_mutually_exclusive_group()
-fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
-fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16.")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index fd493aff..a793cab3 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if model_params < 0:
         model_params = 1000000000000000000000
+    if args.fp32_unet:
+        return torch.float32
+    if args.fp64_unet:
+        return torch.float64
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
 # None means no manual cast
 def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
+    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
         return None
 
     fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
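
Notes on the change: the new --fp32-unet and --fp64-unet flags are checked before the existing bf16/fp16 options in unet_dtype(), and all of them live in the same mutually exclusive argparse group, so only one precision override can be passed at a time. The sketch below illustrates that selection order; pick_unet_dtype and the SimpleNamespace stand-in for the parsed CLI arguments are illustrative only, not ComfyUI's actual API.

    # Minimal sketch of the dtype-selection precedence introduced by this patch.
    # pick_unet_dtype() and the SimpleNamespace args are hypothetical stand-ins.
    from types import SimpleNamespace
    import torch

    def pick_unet_dtype(args):
        # New flags win first; the argparse group guarantees at most one is set.
        if args.fp32_unet:
            return torch.float32
        if args.fp64_unet:
            return torch.float64
        if args.bf16_unet:
            return torch.bfloat16
        if args.fp16_unet:
            return torch.float16
        return None  # fall through to the automatic dtype selection

    args = SimpleNamespace(fp32_unet=False, fp64_unet=True, bf16_unet=False, fp16_unet=False)
    print(pick_unet_dtype(args))  # torch.float64

In practice the flag would be passed on launch (e.g. python main.py --fp32-unet, assuming the standard main.py entry point). The unet_manual_cast() change makes fp64 behave like fp32 there: when the weights are already float32 or float64, no manual cast is applied.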