Only enable attention upcasting on models that actually need it.
parent b0ab31d06c
commit bb4940d837
@@ -207,12 +207,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
 ```embedding:embedding_filename.pt```
 
 
-## How to increase generation speed?
-
-On non Nvidia hardware you can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
-
-```--dont-upcast-attention```
-
 ## How to show high-quality previews?
 
 Use ```--preview-method auto``` to enable previews.
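The README paragraph removed above alludes to fp16 attention overflowing on SD2.x models. The short, self-contained sketch below (synthetic numbers, not ComfyUI code) illustrates that failure mode and why upcasting the attention math to fp32 avoids the black-image result.

```python
import torch

# Synthetic values chosen only so that each attention score (40 * 40 * 64 = 102400)
# exceeds the fp16 maximum of ~65504; real SD2.x activations can hit the same limit.
q = torch.full((1, 77, 64), 40.0)
k = torch.full((1, 77, 64), 40.0)

scores_fp32 = q @ k.transpose(-1, -2)        # fine in fp32
scores_fp16 = scores_fp32.half()             # every entry overflows to inf
print(scores_fp32.max(), scores_fp16.max())  # tensor(102400.) / tensor(inf, dtype=torch.float16)
# softmax subtracts the row max, so inf - inf = nan and the whole attention map becomes NaN,
# which is what shows up downstream as a black image.
print((scores_fp16 - scores_fp16.max()).isnan().any())   # tensor(True)
```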
@@ -51,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
 cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
 
-parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 
 fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
@@ -19,14 +19,6 @@ from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
-# CrossAttn precision handling
-if args.dont_upcast_attention:
-    logging.info("disabling upcasting of attention")
-    _ATTN_PRECISION = None
-else:
-    _ATTN_PRECISION = torch.float32
-
-
 def exists(val):
     return val is not None
 
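The hunk above removes the process-wide `_ATTN_PRECISION` chosen by `--dont-upcast-attention`; the rest of the diff threads a per-module `attn_precision` argument through the constructors instead. A minimal toy sketch of that pattern (not ComfyUI's actual classes), showing why the precision can now differ per model loaded in the same process:

```python
import torch

def toy_attention(q, k, v, attn_precision=None):
    # attn_precision=None keeps the incoming dtype; torch.float32 upcasts the score math.
    if attn_precision is not None:
        q, k = q.to(attn_precision), k.to(attn_precision)
    scores = (q @ k.transpose(-1, -2)) * (q.shape[-1] ** -0.5)
    return torch.softmax(scores, dim=-1).to(v.dtype) @ v

class ToyCrossAttention(torch.nn.Module):
    def __init__(self, attn_precision=None):
        super().__init__()
        # Mirrors the new CrossAttention.attn_precision attribute: precision is a property
        # of the layer (and therefore of the model it belongs to), not a global setting.
        self.attn_precision = attn_precision

    def forward(self, q, k, v):
        return toy_attention(q, k, v, attn_precision=self.attn_precision)

sd2_like = ToyCrossAttention(attn_precision=torch.float32)  # upcasts, like the SD2.x configs below
sd15_like = ToyCrossAttention()                             # attn_precision=None: no upcast
```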
@@ -386,10 +378,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
+        self.attn_precision = attn_precision
 
         self.heads = heads
         self.dim_head = dim_head
@@ -411,15 +404,15 @@ class CrossAttention(nn.Module):
         v = self.to_v(context)
 
         if mask is None:
-            out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
         return self.to_out(out)
 
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
-                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
+                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
         super().__init__()
 
         self.ff_in = ff_in or inner_dim is not None
@@ -434,7 +427,7 @@ class BasicTransformerBlock(nn.Module):
 
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
 
         if disable_temporal_crossattention:
@@ -448,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
                 context_dim_attn2 = context_dim
 
             self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
             self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
 
         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@@ -588,7 +581,7 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None, device=None, operations=ops):
+                 use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
@@ -606,7 +599,7 @@ class SpatialTransformer(nn.Module):
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
                 for d in range(depth)]
         )
         if not use_linear:
@@ -662,6 +655,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         disable_self_attn=False,
         disable_temporal_crossattention=False,
         max_time_embed_period: int = 10000,
+        attn_precision=None,
         dtype=None, device=None, operations=ops
     ):
         super().__init__(
@@ -674,6 +668,7 @@ class SpatialVideoTransformer(SpatialTransformer):
             context_dim=context_dim,
             use_linear=use_linear,
             disable_self_attn=disable_self_attn,
+            attn_precision=attn_precision,
             dtype=dtype, device=device, operations=operations
         )
         self.time_depth = time_depth
@@ -703,6 +698,7 @@ class SpatialVideoTransformer(SpatialTransformer):
                 inner_dim=time_mix_inner_dim,
                 disable_self_attn=disable_self_attn,
                 disable_temporal_crossattention=disable_temporal_crossattention,
+                attn_precision=attn_precision,
                 dtype=dtype, device=device, operations=operations
             )
             for _ in range(self.depth)
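With the plumbing above in place, a single process can hold attention layers built with different precisions. A hedged usage sketch follows; the `attn_precision` keyword comes from the diff, while the import path and the example dimensions are assumptions.

```python
import torch
from comfy.ldm.modules.attention import CrossAttention  # assumed module path

# SD2.x-style layer: attention math upcast to fp32
attn_upcast = CrossAttention(query_dim=1024, heads=16, dim_head=64,
                             attn_precision=torch.float32)

# SD1.x-style layer: attn_precision defaults to None, so no upcast
attn_native = CrossAttention(query_dim=768, heads=8, dim_head=96)
```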
@@ -431,6 +431,7 @@ class UNetModel(nn.Module):
         video_kernel_size=None,
         disable_temporal_crossattention=False,
         max_ddpm_temb_period=10000,
+        attn_precision=None,
         device=None,
         operations=ops,
     ):
@@ -550,13 +551,14 @@ class UNetModel(nn.Module):
                     disable_self_attn=disable_self_attn,
                     disable_temporal_crossattention=disable_temporal_crossattention,
                     max_time_embed_period=max_ddpm_temb_period,
+                    attn_precision=attn_precision,
                     dtype=self.dtype, device=device, operations=operations
                 )
             else:
                 return SpatialTransformer(
                     ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
                     disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
-                    use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                    use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                 )
 
     def get_resblock(
@@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
         "use_temporal_attention": False,
     }
 
+    unet_extra_config = {
+        "num_heads": -1,
+        "num_head_channels": 64,
+        "attn_precision": torch.float32,
+    }
+
     latent_format = latent_formats.SD15
 
     def model_type(self, state_dict, prefix=""):
@@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
         "use_temporal_resblock": True
     }
 
+    unet_extra_config = {
+        "num_heads": -1,
+        "num_head_channels": 64,
+        "attn_precision": torch.float32,
+    }
+
     clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
 
     latent_format = latent_formats.SD15
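The two `unet_extra_config` additions are what turn upcasting back on for the models that need it (SD2.x and SVD), while SD1.x-style configs never pass `attn_precision` and therefore fall back to the new `None` default. Below is a hedged sketch of how such a per-model dict could end up as the `UNetModel` argument; the merge step is an assumption about how `supported_models_base` applies `unet_extra_config`, and only the dict contents and the `attn_precision` parameter come from this diff.

```python
import torch

detected_unet_config = {"in_channels": 4, "context_dim": 1024}   # illustrative values only
unet_extra_config = {                                             # from the SD20 / SVD_img2vid hunks above
    "num_heads": -1,
    "num_head_channels": 64,
    "attn_precision": torch.float32,
}

unet_config = {**detected_unet_config, **unet_extra_config}
# UNetModel(**unet_config) would then build its SpatialTransformer blocks with
# attn_precision=torch.float32, while SD1.x configs leave it at the default None.
```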