Disable xformers in VAE when xformers == 0.0.18
This commit is contained in:
parent
af291e6f69
commit
e46b1c3034
|
@ -9,7 +9,7 @@ from typing import Optional, Any
|
|||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
import model_management
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
|
@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
|||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
if model_management.xformers_enabled() and attn_type == "vanilla":
|
||||
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
||||
attn_type = "vanilla-xformers"
|
||||
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
|
||||
attn_type = "vanilla-pytorch"
|
||||
|
|
|
@ -199,11 +199,25 @@ def get_autocast_device(dev):
|
|||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
|
||||
def xformers_enabled():
|
||||
if vram_state == CPU:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILBLE
|
||||
|
||||
|
||||
def xformers_enabled_vae():
|
||||
enabled = xformers_enabled()
|
||||
if not enabled:
|
||||
return False
|
||||
try:
|
||||
#0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
|
||||
if xformers.version.__version__ == "0.0.18":
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
return enabled
|
||||
|
||||
def pytorch_attention_enabled():
|
||||
return ENABLE_PYTORCH_ATTENTION
|
||||
|
||||
|
|
Loading…
Reference in New Issue