Enable bf16-vae by default on ampere and up.

This commit is contained in:
comfyanonymous 2023-08-27 23:06:19 -04:00
parent 1c794a2161
commit b8c7c770d3
2 changed files with 24 additions and 14 deletions

View File

@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@ -148,15 +148,27 @@ def is_nvidia():
return True
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16
except:
pass
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
VAE_DTYPE = torch.bfloat16
elif args.fp32_vae:
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
@ -228,6 +240,7 @@ try:
except:
print("Could not pick default device.")
print("VAE dtype:", VAE_DTYPE)
current_loaded_models = []
@ -448,12 +461,8 @@ def vae_offload_device():
return torch.device("cpu")
def vae_dtype():
if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
return torch.bfloat16
else:
return torch.float32
global VAE_DTYPE
return VAE_DTYPE
def get_autocast_device(dev):
if hasattr(dev, 'type'):