Only enable attention upcasting on models that actually need it.

This commit is contained in:
comfyanonymous 2024-05-14 15:18:00 -04:00
parent b0ab31d06c
commit bb4940d837
5 changed files with 27 additions and 24 deletions

View File

@ -207,12 +207,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt```
## How to increase generation speed?
On non Nvidia hardware you can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
```--dont-upcast-attention```
## How to show high-quality previews?
Use ```--preview-method auto``` to enable previews.

View File

@ -51,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")

View File

@ -19,14 +19,6 @@ from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
logging.info("disabling upcasting of attention")
_ATTN_PRECISION = None
else:
_ATTN_PRECISION = torch.float32
def exists(val):
return val is not None
@ -386,10 +378,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
@ -411,15 +404,15 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
@ -434,7 +427,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
@ -448,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -588,7 +581,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=ops):
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
@ -606,7 +599,7 @@ class SpatialTransformer(nn.Module):
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
)
if not use_linear:
@ -662,6 +655,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
@ -674,6 +668,7 @@ class SpatialVideoTransformer(SpatialTransformer):
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
@ -703,6 +698,7 @@ class SpatialVideoTransformer(SpatialTransformer):
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)

View File

@ -431,6 +431,7 @@ class UNetModel(nn.Module):
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
device=None,
operations=ops,
):
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(

View File

@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False,
}
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
latent_format = latent_formats.SD15
def model_type(self, state_dict, prefix=""):
@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True
}
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15