Add debug options to force on and off attention upcasting.
parent 58f8388020 · commit 46daf0a9a7
@@ -95,6 +95,11 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

+upcast = parser.add_mutually_exclusive_group()
+upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
+upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+
+
 vram_group = parser.add_mutually_exclusive_group()
 vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")

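For reference, a minimal standalone sketch of how a mutually exclusive argparse group like the one added above behaves; the two flag names come from the diff, while the throwaway parser around them is purely illustrative:

import argparse

# Standalone parser mirroring only the two new flags from the diff.
parser = argparse.ArgumentParser()
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true")
upcast.add_argument("--dont-upcast-attention", action="store_true")

args = parser.parse_args(["--force-upcast-attention"])
print(args.force_upcast_attention)  # True (dashes become underscores in attribute names)
print(args.dont_upcast_attention)   # False

# Passing both flags at once makes argparse exit with a
# "not allowed with argument" error, since they share one mutually exclusive group.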
@@ -19,6 +19,14 @@ from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init

+def get_attn_precision(attn_precision):
+    if args.dont_upcast_attention:
+        return None
+    if attn_precision is None and args.force_upcast_attention:
+        return torch.float32
+    return attn_precision
+
+
 def exists(val):
     return val is not None

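Stated plainly, the new helper resolves precision as follows: --dont-upcast-attention always wins and disables upcasting, --force-upcast-attention only applies when no precision was already requested, and otherwise the requested precision passes through unchanged. A small self-contained illustration (args is stubbed with SimpleNamespace here; the real code reads the parsed comfy.cli_args.args):

import torch
from types import SimpleNamespace

def get_attn_precision(attn_precision, args):
    # Same logic as the diff, with args passed explicitly for the demo.
    if args.dont_upcast_attention:
        return None
    if attn_precision is None and args.force_upcast_attention:
        return torch.float32
    return attn_precision

default = SimpleNamespace(force_upcast_attention=False, dont_upcast_attention=False)
forced = SimpleNamespace(force_upcast_attention=True, dont_upcast_attention=False)
disabled = SimpleNamespace(force_upcast_attention=False, dont_upcast_attention=True)

print(get_attn_precision(None, default))            # None: no upcast unless the model asks for one
print(get_attn_precision(None, forced))             # torch.float32: upcast forced on
print(get_attn_precision(torch.float32, disabled))  # None: upcast forced off for debugging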
@@ -78,6 +86,8 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5

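The hunks below repeat the same one-line change for the sub-quadratic and split backends. The diff does not show how attn_precision is consumed further down; as a rough sketch only (plain scaled-dot-product attention, not ComfyUI's actual implementation), an upcast dtype is typically applied to the q·kT similarity and softmax step, which is where fp16 overflow tends to produce the black-image failures the flag help refers to:

import torch

def scaled_dot_product(q, k, v, scale, attn_precision=None):
    # Sketch: when an upcast dtype is given, compute the similarity matrix
    # and softmax in that dtype, then cast back for the value matmul.
    if attn_precision is not None:
        q, k = q.to(attn_precision), k.to(attn_precision)
    sim = torch.einsum("bid,bjd->bij", q, k) * scale
    attn = sim.softmax(dim=-1).to(v.dtype)
    return torch.einsum("bij,bjd->bid", attn, v)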
@@ -128,6 +138,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):


 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = query.shape
     dim_head //= heads

@@ -188,6 +200,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
     return hidden_states

 def attention_split(q, k, v, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5