diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index f1306ef7..38718b66 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index d64dce18..4e0e6a0a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False
 
+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -347,7 +363,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False
 
@@ -438,7 +454,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
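
Note: with this patch, PyTorch 2.0 scaled-dot-product attention becomes the default on Nvidia GPUs when torch >= 2 is installed, xformers is unavailable, and no other cross-attention backend was requested on the command line; the previous sub-quadratic default can be restored explicitly with the new `--use-quad-cross-attention` flag. Below is a minimal standalone sketch of that decision, not part of the patch; the helper name and boolean parameters are hypothetical and only mirror the logic added to `model_management.py`.

```python
# Hypothetical sketch of the new default-selection logic; names are illustrative only.
import torch

def should_enable_pytorch_attention(use_pytorch, use_split, use_quad,
                                    xformers_available, nvidia_gpu):
    if use_pytorch:
        # --use-pytorch-cross-attention always wins.
        return True
    if xformers_available or use_split or use_quad:
        # xformers will be used, or another backend was requested explicitly.
        return False
    try:
        # Auto-enable only on Nvidia with a torch 2.x (or newer) build,
        # mirroring the patch's int(torch_version[0]) >= 2 check.
        return bool(nvidia_gpu) and int(torch.__version__[0]) >= 2
    except Exception:
        # Unparseable version string: fall back to the old behaviour.
        return False
```

For example, `should_enable_pytorch_attention(False, False, False, False, True)` returns True on a torch 2.x install, which is the case where the patch then turns on the SDP kernels via `torch.backends.cuda`.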