diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index b4f22f31..fda24543 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
-fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e5c80bf6..aca8af99 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -148,15 +148,27 @@ def is_nvidia():
             return True
 
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+VAE_DTYPE = torch.float32
 
-if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-    try:
-        if is_nvidia():
-            torch_version = torch.version.__version__
-            if int(torch_version[0]) >= 2:
+
+try:
+    if is_nvidia():
+        torch_version = torch.version.__version__
+        if int(torch_version[0]) >= 2:
+            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    except:
-        pass
+            if torch.cuda.is_bf16_supported():
+                VAE_DTYPE = torch.bfloat16
+except:
+    pass
+
+if args.fp16_vae:
+    VAE_DTYPE = torch.float16
+elif args.bf16_vae:
+    VAE_DTYPE = torch.bfloat16
+elif args.fp32_vae:
+    VAE_DTYPE = torch.float32
+
 
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
@@ -228,6 +240,7 @@ try:
 except:
     print("Could not pick default device.")
 
+print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
@@ -448,12 +461,8 @@ def vae_offload_device():
         return torch.device("cpu")
 
 def vae_dtype():
-    if args.fp16_vae:
-        return torch.float16
-    elif args.bf16_vae:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    global VAE_DTYPE
+    return VAE_DTYPE
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):