Reduce memory usage when applying DORA: #3557
This commit is contained in:
parent
58c9838274
commit
efa5a711b2
|
@ -9,7 +9,7 @@ import comfy.model_management
|
|||
from comfy.types import UnetWrapperFunction
|
||||
|
||||
|
||||
def apply_weight_decompose(dora_scale, weight):
|
||||
def weight_decompose_scale(dora_scale, weight):
|
||||
weight_norm = (
|
||||
weight.transpose(0, 1)
|
||||
.reshape(weight.shape[1], -1)
|
||||
|
@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
|
|||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
return weight * (dora_scale / weight_norm).type(weight.dtype)
|
||||
return (dora_scale / weight_norm).type(weight.dtype)
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
@ -365,7 +365,7 @@ class ModelPatcher:
|
|||
try:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
if dora_scale is not None:
|
||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
|
@ -407,7 +407,7 @@ class ModelPatcher:
|
|||
try:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
if dora_scale is not None:
|
||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
|
@ -439,7 +439,7 @@ class ModelPatcher:
|
|||
try:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
if dora_scale is not None:
|
||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
|
@ -456,7 +456,7 @@ class ModelPatcher:
|
|||
try:
|
||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||
if dora_scale is not None:
|
||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue