Disable xformers in VAE when xformers == 0.0.18

comfyanonymous 2023-04-04 22:22:02 -04:00
parent af291e6f69
commit e46b1c3034
2 changed files with 16 additions and 2 deletions

comfy/ldm/modules/diffusionmodules/model.py

@@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management

-if model_management.xformers_enabled():
+if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
@@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):

 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if model_management.xformers_enabled() and attn_type == "vanilla":
+    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"

comfy/model_management.py

@@ -199,11 +199,25 @@ def get_autocast_device(dev):
         return dev.type
     return "cuda"

 def xformers_enabled():
     if vram_state == CPU:
         return False
     return XFORMERS_IS_AVAILBLE

+def xformers_enabled_vae():
+    enabled = xformers_enabled()
+    if not enabled:
+        return False
+
+    try:
+        #0.0.18 has a bug where NaN is returned when inputs are too big (1152x1920 res images and above)
+        if xformers.version.__version__ == "0.0.18":
+            return False
+    except:
+        pass
+
+    return enabled
+
 def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
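
The new helper pins on the exact version string, so only the known-bad 0.0.18 release loses xformers in the VAE; other releases keep memory-efficient attention. A hedged standalone sketch of the same check outside ComfyUI (assuming only that xformers exposes xformers.version.__version__, which the hunk above already relies on):

try:
    import xformers.version
    # Exact match against the release that can return NaNs for large inputs
    # (roughly 1152x1920 resolution and above, per the comment in the commit).
    vae_can_use_xformers = xformers.version.__version__ != "0.0.18"
except ImportError:
    # xformers is not installed; the VAE uses plain or PyTorch attention instead.
    vae_can_use_xformers = False

print(vae_can_use_xformers)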