Enable bf16-vae by default on ampere and up.
parent 1c794a2161
commit b8c7c770d3
@@ -54,7 +54,8 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
-fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
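For reference, a minimal standalone sketch of how the updated flag group parses (not the project's full parser; only the three VAE flags from the hunk above are reproduced):

import argparse

# Sketch: the three VAE precision flags share one mutually exclusive group,
# so argparse rejects combinations such as "--fp16-vae --bf16-vae".
parser = argparse.ArgumentParser()
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")

args = parser.parse_args(["--bf16-vae"])
print(args.bf16_vae, args.fp16_vae, args.fp32_vae)  # True False False

Passing none of the flags leaves all three False, which is what lets the auto-detection added in the next hunk pick the default.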
@@ -148,15 +148,27 @@ def is_nvidia():
         return True
 
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+VAE_DTYPE = torch.float32
 
-if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-    try:
-        if is_nvidia():
-            torch_version = torch.version.__version__
-            if int(torch_version[0]) >= 2:
+try:
+    if is_nvidia():
+        torch_version = torch.version.__version__
+        if int(torch_version[0]) >= 2:
+            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    except:
-        pass
+            if torch.cuda.is_bf16_supported():
+                VAE_DTYPE = torch.bfloat16
+except:
+    pass
+
+if args.fp16_vae:
+    VAE_DTYPE = torch.float16
+elif args.bf16_vae:
+    VAE_DTYPE = torch.bfloat16
+elif args.fp32_vae:
+    VAE_DTYPE = torch.float32
 
 
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
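The new default selection can be exercised in isolation. A minimal sketch, assuming a CUDA build of PyTorch; the hypothetical pick_vae_dtype helper mirrors the logic above but drops the is_nvidia()/torch-version gating for brevity:

import torch

def pick_vae_dtype(fp16_vae=False, bf16_vae=False, fp32_vae=False):
    # Default to fp32, upgrade to bf16 when the GPU reports bf16 support
    # (Ampere and newer), then let explicit flags override the detection.
    vae_dtype = torch.float32
    try:
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            vae_dtype = torch.bfloat16
    except Exception:
        pass
    if fp16_vae:
        vae_dtype = torch.float16
    elif bf16_vae:
        vae_dtype = torch.bfloat16
    elif fp32_vae:
        vae_dtype = torch.float32
    return vae_dtype

print(pick_vae_dtype())               # bfloat16 on Ampere and up, float32 otherwise
print(pick_vae_dtype(fp16_vae=True))  # float16 regardless of detection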
@@ -228,6 +240,7 @@ try:
 except:
     print("Could not pick default device.")
 
+print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
@@ -448,12 +461,8 @@ def vae_offload_device():
         return torch.device("cpu")
 
 def vae_dtype():
-    if args.fp16_vae:
-        return torch.float16
-    elif args.bf16_vae:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    global VAE_DTYPE
+    return VAE_DTYPE
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
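A sketch of how a caller might consume the simplified accessor; the module-level VAE_DTYPE placeholder and the Conv2d standing in for a VAE are illustrative assumptions, not the project's actual objects:

import torch

VAE_DTYPE = torch.float32  # placeholder; in the real module this is set by the detection block earlier in the file

def vae_dtype():
    global VAE_DTYPE
    return VAE_DTYPE

# Hypothetical caller: move a VAE-like module to the selected precision.
vae = torch.nn.Conv2d(3, 4, 3)
vae.to(dtype=vae_dtype())
print("VAE dtype:", vae_dtype())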