Only enable attention upcasting on models that actually need it.

This commit is contained in:
comfyanonymous 2024-05-14 15:18:00 -04:00
parent b0ab31d06c
commit bb4940d837
5 changed files with 27 additions and 24 deletions

View File

@ -207,12 +207,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt``` ```embedding:embedding_filename.pt```
## How to increase generation speed?
On non-Nvidia hardware you can set this command-line option to disable the upcasting to fp32 in some cross-attention operations, which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or PyTorch attention, this option does not do anything.
```--dont-upcast-attention```
## How to show high-quality previews? ## How to show high-quality previews?
Use ```--preview-method auto``` to enable previews. Use ```--preview-method auto``` to enable previews.

View File

@ -51,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group() fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")

View File

@ -19,14 +19,6 @@ from comfy.cli_args import args
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
logging.info("disabling upcasting of attention")
_ATTN_PRECISION = None
else:
_ATTN_PRECISION = torch.float32
def exists(val): def exists(val):
return val is not None return val is not None
@ -386,10 +378,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.attn_precision = attn_precision
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
@ -411,15 +404,15 @@ class CrossAttention(nn.Module):
v = self.to_v(context) v = self.to_v(context)
if mask is None: if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION) out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else: else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION) out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out) return self.to_out(out)
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.ff_in = ff_in or inner_dim is not None self.ff_in = ff_in or inner_dim is not None
@ -434,7 +427,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention: if disable_temporal_crossattention:
@ -448,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -588,7 +581,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=ops): use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
if exists(context_dim) and not isinstance(context_dim, list): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
@ -606,7 +599,7 @@ class SpatialTransformer(nn.Module):
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
@ -662,6 +655,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False, disable_self_attn=False,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_time_embed_period: int = 10000, max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops dtype=None, device=None, operations=ops
): ):
super().__init__( super().__init__(
@ -674,6 +668,7 @@ class SpatialVideoTransformer(SpatialTransformer):
context_dim=context_dim, context_dim=context_dim,
use_linear=use_linear, use_linear=use_linear,
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
self.time_depth = time_depth self.time_depth = time_depth
@ -703,6 +698,7 @@ class SpatialVideoTransformer(SpatialTransformer):
inner_dim=time_mix_inner_dim, inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention, disable_temporal_crossattention=disable_temporal_crossattention,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
for _ in range(self.depth) for _ in range(self.depth)

View File

@ -431,6 +431,7 @@ class UNetModel(nn.Module):
video_kernel_size=None, video_kernel_size=None,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_ddpm_temb_period=10000, max_ddpm_temb_period=10000,
attn_precision=None,
device=None, device=None,
operations=ops, operations=ops,
): ):
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention, disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period, max_time_embed_period=max_ddpm_temb_period,
attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations dtype=self.dtype, device=device, operations=operations
) )
else: else:
return SpatialTransformer( return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim, ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
) )
def get_resblock( def get_resblock(

View File

@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True "use_temporal_resblock": True
} }
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15