Use pytorch attention by default on nvidia when xformers isn't present.
Add a new argument --use-quad-cross-attention
parent 9b93b920be
commit 8248babd44
@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
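
The three attention flags now sit in one argparse mutually exclusive group, so at most one of them can be passed per run. A minimal standalone sketch of that behavior (not the project's actual parser, which also carries the other options shown above):

    import argparse

    parser = argparse.ArgumentParser()
    attn_group = parser.add_mutually_exclusive_group()
    attn_group.add_argument("--use-split-cross-attention", action="store_true")
    attn_group.add_argument("--use-quad-cross-attention", action="store_true")
    attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")

    # A single flag parses fine and maps to an underscored attribute.
    args = parser.parse_args(["--use-quad-cross-attention"])
    assert args.use_quad_cross_attention and not args.use_split_cross_attention

    # Passing two flags from the group, e.g. --use-quad-cross-attention together
    # with --use-split-cross-attention, makes argparse exit with a "not allowed
    # with" usage error instead of parsing.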
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False
 
+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
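
In plain terms, the new block turns PyTorch attention on by default only when the user passed no attention flag, xformers is unavailable, the build is CUDA (Nvidia), and the PyTorch major version is 2 or newer. A minimal sketch of that rule with the module globals and parsed args replaced by plain parameters; the cpu_state check from is_nvidia() is folded into the torch.version.cuda test here, so this is illustrative only:

    import torch

    def default_enable_pytorch_attention(xformers_available, use_split, use_quad, use_pytorch):
        # --use-pytorch-cross-attention forces it on; any other explicit choice,
        # or a working xformers install, leaves it off.
        if use_pytorch:
            return True
        if xformers_available or use_split or use_quad:
            return False
        try:
            # torch.version.cuda is a version string on CUDA builds, None on CPU/ROCm builds.
            if torch.version.cuda and int(torch.__version__[0]) >= 2:
                return True
        except Exception:
            pass
        return False

    print(default_enable_pytorch_attention(False, False, False, False))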
@@ -347,7 +363,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False
 
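
The only change here is swapping the raw torch.version.cuda check for the new is_nvidia() helper: torch.version.cuda reports how the wheel was built, while is_nvidia() also requires that the runtime device state is GPU. A standalone sketch of that difference (the enum and state variable are stand-ins for the ones defined elsewhere in the file):

    import enum
    import torch

    class CPUState(enum.Enum):  # stand-in for the module's own enum
        GPU = enum.auto()
        CPU = enum.auto()

    cpu_state = CPUState.CPU    # e.g. the user forced CPU-only execution

    def is_nvidia_sketch():
        if cpu_state == CPUState.GPU:
            if torch.version.cuda:
                return True
        return False

    print(bool(torch.version.cuda))  # True on a CUDA wheel, even when running on CPU
    print(is_nvidia_sketch())        # False here, because cpu_state is not GPU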
@@ -438,7 +454,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
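
Same substitution as above: the CUDA cache flush and IPC handle collection stay gated to Nvidia builds, and ROCm (where torch.cuda.is_available() is also true but torch.version.hip is set instead of torch.version.cuda) keeps skipping it, per the comment in the diff. A hedged, simplified sketch of that gate, not the project's actual function:

    import torch

    def soft_empty_cache_sketch(nvidia: bool) -> None:
        if torch.cuda.is_available():
            if nvidia:
                # CUDA only: release cached allocator blocks and collect CUDA IPC handles.
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
            # Intentionally a no-op on ROCm, matching the comment in the diff.

    soft_empty_cache_sketch(nvidia=bool(torch.version.cuda))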