Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks
commit 55014293b1 (parent e80dc96627)
@@ -19,6 +19,7 @@ class EnumHookMode(enum.Enum):
 class EnumHookType(enum.Enum):
     Weight = "weight"
     Patch = "patch"
+    AddModel = "addmodel"
 
 class EnumWeightTarget(enum.Enum):
     Model = "model"
@@ -121,10 +122,22 @@ class PatchHook(Hook):
     def clone(self, subtype: Callable=None):
         if subtype is None:
             subtype = type(self)
-        c: PatchHook = super().clone(type(self))
+        c: PatchHook = super().clone(subtype)
         c.patches = self.patches
         return c
 
+class AddModelHook(Hook):
+    def __init__(self, model: 'ModelPatcher'):
+        super().__init__(hook_type=EnumHookType.AddModel)
+        self.model = model
+
+    def clone(self, subtype: Callable=None):
+        if subtype is None:
+            subtype = type(self)
+        c: AddModelHook = super().clone(subtype)
+        c.model = self.model
+        return c
+
 class HookGroup:
     def __init__(self):
         self.hooks: List[Hook] = []
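Two things land in this hunk: PatchHook.clone now forwards an explicitly requested subtype to the base clone instead of hardcoding type(self), and the new AddModelHook wraps a ModelPatcher so a hook can pull an extra model into a load. A minimal standalone sketch of why the subtype fix matters (stub classes, not the real hooks module):

from typing import Callable

class Hook:
    def clone(self, subtype: Callable = None):
        if subtype is None:
            subtype = type(self)
        return subtype()  # base allocates whatever class was requested

class PatchHook(Hook):
    def clone(self, subtype: Callable = None):
        if subtype is None:
            subtype = type(self)
        # forwarding subtype (the fix) instead of type(self) lets a caller
        # request a more specific class from a parent's clone()
        return super().clone(subtype)

class SpecialPatchHook(PatchHook):
    pass

h = PatchHook()
assert type(h.clone(SpecialPatchHook)) is SpecialPatchHook  # holds only with the fix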
@@ -108,6 +108,8 @@ class CallbacksMP:
     ON_PREPARE_STATE = "on_prepare_state"
     ON_APPLY_HOOKS = "on_apply_hooks"
     ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
+    ON_INJECT_MODEL = "on_inject_model"
+    ON_EJECT_MODEL = "on_eject_model"
 
     @classmethod
     def init_callbacks(cls):
@@ -119,8 +121,37 @@ class CallbacksMP:
             cls.ON_PREPARE_STATE: [],
             cls.ON_APPLY_HOOKS: [],
             cls.ON_REGISTER_ALL_HOOK_PATCHES: [],
+            cls.ON_INJECT_MODEL: [],
+            cls.ON_EJECT_MODEL: [],
         }
 
+class AutoPatcherEjector:
+    def __init__(self, model: 'ModelPatcher', skip_until_exit=False):
+        self.model = model
+        self.was_injected = False
+        self.prev_skip_injection = False
+        self.skip_until_exit = skip_until_exit
+
+    def __enter__(self):
+        self.was_injected = False
+        self.prev_skip_injection = self.model.skip_injection
+        if self.skip_until_exit:
+            self.model.skip_injection = True
+        if self.model.is_injected:
+            self.model.eject_model()
+            self.was_injected = True
+
+    def __exit__(self, *args):
+        if self.was_injected:
+            if self.skip_until_exit or not self.model.skip_injection:
+                self.model.inject_model()
+        self.model.skip_injection = self.prev_skip_injection
+
+class PatcherInjection:
+    def __init__(self, inject: Callable, eject: Callable):
+        self.inject = inject
+        self.eject = eject
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
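AutoPatcherEjector is the context manager behind the use_ejected() calls threaded through the rest of this commit: on entry it ejects any active injections, and on exit it re-injects them if they were active before, so weight operations always see the unmodified module tree. A minimal sketch with a stub patcher (FakePatcher only models the attributes the ejector touches and ignores the skip_injection interplay; the real counterpart is ModelPatcher):

class FakePatcher:
    def __init__(self):
        self.is_injected = True
        self.skip_injection = False
    def inject_model(self):
        self.is_injected = True
    def eject_model(self):
        self.is_injected = False
    def use_ejected(self, skip_injection=False):
        return AutoPatcherEjector(self, skip_until_exit=skip_injection)

p = FakePatcher()
with p.use_ejected():
    assert p.is_injected is False  # ejected for the duration of the block
assert p.is_injected is True       # re-injected on exit, since it was injected before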
@@ -143,9 +174,13 @@ class ModelPatcher:
         self.patches_uuid = uuid.uuid4()
 
         self.attachments: Dict[str] = {}
-        self.additional_models: list[ModelPatcher] = []
+        self.additional_models: Dict[str, List[ModelPatcher]] = {}
         self.callbacks: Dict[str, List[Callable]] = CallbacksMP.init_callbacks()
 
+        self.is_injected = False
+        self.skip_injection = False
+        self.injections: Dict[str, List[PatcherInjection]] = {}
+
         self.hook_patches: Dict[comfy.hooks._HookRef] = {}
         self.hook_patches_backup: Dict[comfy.hooks._HookRef] = {}
         self.hook_backup: Dict[str, Tuple[torch.Tensor, torch.device]] = {}
@@ -196,11 +231,16 @@ class ModelPatcher:
             else:
                 n.attachments[k] = self.attachments[k]
         # additional models
-        for m in self.additional_models:
-            n.additional_models.append(m.clone())
+        for k, c in self.additional_models.items():
+            n.additional_models[k] = [x.clone() for x in c]
         # callbacks
         for k, c in self.callbacks.items():
             n.callbacks[k] = c.copy()
+        # injection
+        n.is_injected = self.is_injected
+        n.skip_injection = self.skip_injection
+        for k, i in self.injections.items():
+            n.injections[k] = i.copy()
         # hooks
         n.hook_patches = create_hook_patches_clone(self.hook_patches)
         n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
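Note the two copy policies in clone(): additional models are re-cloned per key, so each copy tracks its own patch state, while injections only copy the per-key lists and share the PatcherInjection objects. A standalone sketch of the distinction (illustrative names, not ComfyUI API):

class Item:
    def clone(self):
        return Item()

additional_models = {"controlnet": [Item()]}
injections = {"feature": [object()]}

# per-key deep clone, as done for additional_models above
cloned_models = {k: [x.clone() for x in v] for k, v in additional_models.items()}
# per-key list copy, as done for injections above
copied_injections = {k: v.copy() for k, v in injections.items()}

assert cloned_models["controlnet"][0] is not additional_models["controlnet"][0]
assert copied_injections["feature"][0] is injections["feature"][0]
assert copied_injections["feature"] is not injections["feature"]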
@@ -342,27 +382,28 @@ class ModelPatcher:
         return self.model.get_dtype()
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        model_sd = self.model.state_dict()
-        for k in patches:
-            offset = None
-            function = None
-            if isinstance(k, str):
-                key = k
-            else:
-                offset = k[1]
-                key = k[0]
-                if len(k) > 2:
-                    function = k[2]
-
-            if key in model_sd:
-                p.add(k)
-                current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
-                self.patches[key] = current_patches
-
-        self.patches_uuid = uuid.uuid4()
-        return list(p)
+        with self.use_ejected():
+            p = set()
+            model_sd = self.model.state_dict()
+            for k in patches:
+                offset = None
+                function = None
+                if isinstance(k, str):
+                    key = k
+                else:
+                    offset = k[1]
+                    key = k[0]
+                    if len(k) > 2:
+                        function = k[2]
+
+                if key in model_sd:
+                    p.add(k)
+                    current_patches = self.patches.get(key, [])
+                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                    self.patches[key] = current_patches
+
+            self.patches_uuid = uuid.uuid4()
+            return list(p)
 
     def get_key_patches(self, filter_prefix=None):
         model_sd = self.model_state_dict()
@@ -386,13 +427,14 @@ class ModelPatcher:
         return p
 
     def model_state_dict(self, filter_prefix=None):
-        sd = self.model.state_dict()
-        keys = list(sd.keys())
-        if filter_prefix is not None:
-            for k in keys:
-                if not k.startswith(filter_prefix):
-                    sd.pop(k)
-        return sd
+        with self.use_ejected():
+            sd = self.model.state_dict()
+            keys = list(sd.keys())
+            if filter_prefix is not None:
+                for k in keys:
+                    if not k.startswith(filter_prefix):
+                        sd.pop(k)
+            return sd
 
     def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.patches:
@@ -417,108 +459,116 @@ class ModelPatcher:
         comfy.utils.set_attr_param(self.model, key, out_weight)
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
-        self.unpatch_hooks()
-        mem_counter = 0
-        patch_counter = 0
-        lowvram_counter = 0
-        loading = []
-        for n, m in self.model.named_modules():
-            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
-                loading.append((comfy.model_management.module_size(m), n, m))
-
-        load_completely = []
-        loading.sort(reverse=True)
-        for x in loading:
-            n = x[1]
-            m = x[2]
-            module_mem = x[0]
-
-            lowvram_weight = False
-
-            if not full_load and hasattr(m, "comfy_cast_weights"):
-                if mem_counter + module_mem >= lowvram_model_memory:
-                    lowvram_weight = True
-                    lowvram_counter += 1
-                    if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
-                        continue
-
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
-
-            if lowvram_weight:
-                if weight_key in self.patches:
-                    if force_patch_weights:
-                        self.patch_weight_to_device(weight_key)
-                    else:
-                        m.weight_function = LowVramPatch(weight_key, self.patches)
-                        patch_counter += 1
-                if bias_key in self.patches:
-                    if force_patch_weights:
-                        self.patch_weight_to_device(bias_key)
-                    else:
-                        m.bias_function = LowVramPatch(bias_key, self.patches)
-                        patch_counter += 1
-
-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
-            else:
-                if hasattr(m, "comfy_cast_weights"):
-                    if m.comfy_cast_weights:
-                        wipe_lowvram_weight(m)
-
-                if hasattr(m, "weight"):
-                    mem_counter += module_mem
-                    load_completely.append((module_mem, n, m))
-
-        load_completely.sort(reverse=True)
-        for x in load_completely:
-            n = x[1]
-            m = x[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
-            if hasattr(m, "comfy_patched_weights"):
-                if m.comfy_patched_weights == True:
-                    continue
-
-            self.patch_weight_to_device(weight_key, device_to=device_to)
-            self.patch_weight_to_device(bias_key, device_to=device_to)
-            logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
-            m.comfy_patched_weights = True
-
-        for x in load_completely:
-            x[2].to(device_to)
-
-        if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
-            self.model.model_lowvram = True
-        else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
-            self.model.model_lowvram = False
-            if full_load:
-                self.model.to(device_to)
-                mem_counter = self.model_size()
-
-        self.model.lowvram_patch_counter += patch_counter
-        self.model.device = device_to
-        self.model.model_loaded_weight_memory = mem_counter
-        self.apply_hooks(self.forced_hooks)
+        with self.use_ejected():
+            self.unpatch_hooks()
+            mem_counter = 0
+            patch_counter = 0
+            lowvram_counter = 0
+            loading = []
+            for n, m in self.model.named_modules():
+                if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
+                    loading.append((comfy.model_management.module_size(m), n, m))
+
+            load_completely = []
+            loading.sort(reverse=True)
+            for x in loading:
+                n = x[1]
+                m = x[2]
+                module_mem = x[0]
+
+                lowvram_weight = False
+
+                if not full_load and hasattr(m, "comfy_cast_weights"):
+                    if mem_counter + module_mem >= lowvram_model_memory:
+                        lowvram_weight = True
+                        lowvram_counter += 1
+                        if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
+                            continue
+
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
+                if lowvram_weight:
+                    if weight_key in self.patches:
+                        if force_patch_weights:
+                            self.patch_weight_to_device(weight_key)
+                        else:
+                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            patch_counter += 1
+                    if bias_key in self.patches:
+                        if force_patch_weights:
+                            self.patch_weight_to_device(bias_key)
+                        else:
+                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            patch_counter += 1
+
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                else:
+                    if hasattr(m, "comfy_cast_weights"):
+                        if m.comfy_cast_weights:
+                            wipe_lowvram_weight(m)
+
+                    if hasattr(m, "weight"):
+                        mem_counter += module_mem
+                        load_completely.append((module_mem, n, m))
+
+            load_completely.sort(reverse=True)
+            for x in load_completely:
+                n = x[1]
+                m = x[2]
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+                if hasattr(m, "comfy_patched_weights"):
+                    if m.comfy_patched_weights == True:
+                        continue
+
+                self.patch_weight_to_device(weight_key, device_to=device_to)
+                self.patch_weight_to_device(bias_key, device_to=device_to)
+                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                m.comfy_patched_weights = True
+
+            for x in load_completely:
+                x[2].to(device_to)
+
+            if lowvram_counter > 0:
+                logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+                self.model.model_lowvram = True
+            else:
+                logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+                self.model.model_lowvram = False
+                if full_load:
+                    self.model.to(device_to)
+                    mem_counter = self.model_size()
+
+            self.model.lowvram_patch_counter += patch_counter
+            self.model.device = device_to
+            self.model.model_loaded_weight_memory = mem_counter
+
+            for callback in self.callbacks[CallbacksMP.ON_LOAD]:
+                callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
+
+        self.apply_hooks(self.forced_hooks)
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
-        for k in self.object_patches:
-            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
-            if k not in self.object_patches_backup:
-                self.object_patches_backup[k] = old
-
-        if lowvram_model_memory == 0:
-            full_load = True
-        else:
-            full_load = False
-
-        if load_weights:
-            self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+        with self.use_ejected():
+            for k in self.object_patches:
+                old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+                if k not in self.object_patches_backup:
+                    self.object_patches_backup[k] = old
+
+            if lowvram_model_memory == 0:
+                full_load = True
+            else:
+                full_load = False
+
+            if load_weights:
+                self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+        self.inject_model()
         return self.model
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
+        self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
             if self.model.model_lowvram:
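load() now fires the ON_LOAD callbacks from inside the use_ejected() block, so they run while injections are ejected, and patch_model()/unpatch_model() bracket their work with inject_model()/eject_model(). A hedged example of registering an ON_LOAD callback with the signature the loop above uses (patcher stands for an existing ModelPatcher instance):

import logging

def on_load(patcher, device_to, lowvram_model_memory, force_patch_weights, full_load):
    # called after weights are placed; per the diff above, injections are
    # still ejected at this point
    logging.info("loaded to %s (full_load=%s)", device_to, full_load)

patcher.add_callback(CallbacksMP.ON_LOAD, on_load)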
@@ -555,66 +605,68 @@ class ModelPatcher:
         self.object_patches_backup.clear()
 
     def partially_unload(self, device_to, memory_to_free=0):
-        memory_freed = 0
-        patch_counter = 0
-        unload_list = []
-
-        for n, m in self.model.named_modules():
-            shift_lowvram = False
-            if hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
-                unload_list.append((module_mem, n, m))
-
-        unload_list.sort()
-        for unload in unload_list:
-            if memory_to_free < memory_freed:
-                break
-            module_mem = unload[0]
-            n = unload[1]
-            m = unload[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
-
-            if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
-                for key in [weight_key, bias_key]:
-                    bk = self.backup.get(key, None)
-                    if bk is not None:
-                        if bk.inplace_update:
-                            comfy.utils.copy_to_param(self.model, key, bk.weight)
-                        else:
-                            comfy.utils.set_attr_param(self.model, key, bk.weight)
-                        self.backup.pop(key)
-
-                m.to(device_to)
-                if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self.patches)
-                    patch_counter += 1
-                if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self.patches)
-                    patch_counter += 1
-
-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
-                m.comfy_patched_weights = False
-                memory_freed += module_mem
-                logging.debug("freed {}".format(n))
-
-        self.model.model_lowvram = True
-        self.model.lowvram_patch_counter += patch_counter
-        self.model.model_loaded_weight_memory -= memory_freed
-        return memory_freed
+        with self.use_ejected():
+            memory_freed = 0
+            patch_counter = 0
+            unload_list = []
+
+            for n, m in self.model.named_modules():
+                shift_lowvram = False
+                if hasattr(m, "comfy_cast_weights"):
+                    module_mem = comfy.model_management.module_size(m)
+                    unload_list.append((module_mem, n, m))
+
+            unload_list.sort()
+            for unload in unload_list:
+                if memory_to_free < memory_freed:
+                    break
+                module_mem = unload[0]
+                n = unload[1]
+                m = unload[2]
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
+                if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
+                    for key in [weight_key, bias_key]:
+                        bk = self.backup.get(key, None)
+                        if bk is not None:
+                            if bk.inplace_update:
+                                comfy.utils.copy_to_param(self.model, key, bk.weight)
+                            else:
+                                comfy.utils.set_attr_param(self.model, key, bk.weight)
+                            self.backup.pop(key)
+
+                    m.to(device_to)
+                    if weight_key in self.patches:
+                        m.weight_function = LowVramPatch(weight_key, self.patches)
+                        patch_counter += 1
+                    if bias_key in self.patches:
+                        m.bias_function = LowVramPatch(bias_key, self.patches)
+                        patch_counter += 1
+
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    m.comfy_patched_weights = False
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
+
+            self.model.model_lowvram = True
+            self.model.lowvram_patch_counter += patch_counter
+            self.model.model_loaded_weight_memory -= memory_freed
+            return memory_freed
 
     def partially_load(self, device_to, extra_memory=0):
-        self.unpatch_model(unpatch_weights=False)
-        self.patch_model(load_weights=False)
-        full_load = False
-        if self.model.model_lowvram == False:
-            return 0
-        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
-            full_load = True
-        current_used = self.model.model_loaded_weight_memory
-        self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
-        return self.model.model_loaded_weight_memory - current_used
+        with self.use_ejected(skip_injection=True):
+            self.unpatch_model(unpatch_weights=False)
+            self.patch_model(load_weights=False)
+            full_load = False
+            if self.model.model_lowvram == False:
+                return 0
+            if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
+                full_load = True
+            current_used = self.model.model_loaded_weight_memory
+            self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+            return self.model.model_loaded_weight_memory - current_used
 
     def current_loaded_device(self):
         return self.model.device
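The new injection API in one hedged sketch: a PatcherInjection pairs an inject and an eject callable, each invoked with the owning ModelPatcher; inject_model() returns immediately once injected, and both directions fire their callbacks. (patcher is an assumed existing ModelPatcher; what the callables swap on patcher.model is up to the caller.)

def inject_fn(mp):
    # e.g. wrap or replace a submodule on mp.model here
    print("injecting into", type(mp.model).__name__)

def eject_fn(mp):
    # undo exactly what inject_fn did
    print("ejecting from", type(mp.model).__name__)

patcher.set_injections("my_feature", [PatcherInjection(inject=inject_fn, eject=eject_fn)])
patcher.add_callback(CallbacksMP.ON_INJECT_MODEL, lambda mp: print("inject callbacks fired"))

patcher.inject_model()  # runs inject_fn; a second call is a no-op
patcher.eject_model()   # runs eject_fn and fires ON_EJECT_MODEL callbacks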
@@ -629,13 +681,49 @@ class ModelPatcher:
         for callback in self.callbacks[CallbacksMP.ON_CLEANUP]:
             callback(self)
 
-    def add_callback(self, key, callback: Callable):
+    def get_all_additional_models(self):
+        all_models = []
+        for models in self.additional_models.values():
+            all_models.extend(models)
+        return all_models
+
+    def add_callback(self, key: str, callback: Callable):
         if key not in self.callbacks:
             raise Exception(f"Callback '{key}' is not recognized.")
         self.callbacks[key].append(callback)
 
-    def add_attachment(self, attachment):
-        self.attachments.append(attachment)
+    def set_attachments(self, key: str, attachment):
+        self.attachments[key] = attachment
+
+    def set_injections(self, key: str, injections: List[PatcherInjection]):
+        self.injections[key] = injections
+
+    def set_additional_models(self, key: str, models: List['ModelPatcher']):
+        self.additional_models[key] = models
+
+    def use_ejected(self, skip_injection=False):
+        return AutoPatcherEjector(self, skip_until_exit=skip_injection)
+
+    def inject_model(self):
+        if self.is_injected or self.skip_injection:
+            return
+        for injections in self.injections.values():
+            for inj in injections:
+                inj.inject(self)
+                self.is_injected = True
+        if self.is_injected:
+            for callback in self.callbacks[CallbacksMP.ON_INJECT_MODEL]:
+                callback(self)
+
+    def eject_model(self):
+        if not self.is_injected:
+            return
+        for injections in self.injections.values():
+            for inj in injections:
+                inj.eject(self)
+        self.is_injected = False
+        for callback in self.callbacks[CallbacksMP.ON_EJECT_MODEL]:
+            callback(self)
 
     def pre_run(self):
         for callback in self.callbacks[CallbacksMP.ON_PRE_RUN]:
@@ -685,58 +773,60 @@ class ModelPatcher:
             callback(self, hooks_dict, target)
 
     def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0, is_diff=False):
-        # NOTE: this mirrors behavior of add_patches func
-        if is_diff:
-            comfy.model_management.unload_model_clones(self)
-        current_hook_patches: Dict[str,List] = self.hook_patches.get(hook.hook_ref, {})
-        p = set()
-        model_sd = self.model.state_dict()
-        for k in patches:
-            offset = None
-            function = None
-            if isinstance(k, str):
-                key = k
-            else:
-                offset = k[1]
-                key = k[0]
-                if len(k) > 2:
-                    function = k[2]
-
-            if key in model_sd:
-                p.add(k)
-                current_patches: List[Tuple] = current_hook_patches.get(key, [])
-                if is_diff:
-                    # take difference between desired weight and existing weight to get diff
-                    # TODO: try to implement diff via strength_path/strength_model diff
-                    model_dtype = comfy.utils.get_attr(self.model, key).dtype
-                    if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
-                        diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype)
-                    else:
-                        diff_weight = patches[k]-comfy.utils.get_attr(self.model, key)
-                    current_patches.append((strength_patch, (diff_weight,), strength_model, offset, function))
-                else:
-                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
-                current_hook_patches[key] = current_patches
-        self.hook_patches[hook.hook_ref] = current_hook_patches
-        # since should care about these patches too to determine if same model, reroll patches_uuid
-        self.patches_uuid = uuid.uuid4()
-        return list(p)
+        with self.use_ejected():
+            # NOTE: this mirrors behavior of add_patches func
+            if is_diff:
+                comfy.model_management.unload_model_clones(self)
+            current_hook_patches: Dict[str,List] = self.hook_patches.get(hook.hook_ref, {})
+            p = set()
+            model_sd = self.model.state_dict()
+            for k in patches:
+                offset = None
+                function = None
+                if isinstance(k, str):
+                    key = k
+                else:
+                    offset = k[1]
+                    key = k[0]
+                    if len(k) > 2:
+                        function = k[2]
+
+                if key in model_sd:
+                    p.add(k)
+                    current_patches: List[Tuple] = current_hook_patches.get(key, [])
+                    if is_diff:
+                        # take difference between desired weight and existing weight to get diff
+                        # TODO: try to implement diff via strength_path/strength_model diff
+                        model_dtype = comfy.utils.get_attr(self.model, key).dtype
+                        if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                            diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype)
+                        else:
+                            diff_weight = patches[k]-comfy.utils.get_attr(self.model, key)
+                        current_patches.append((strength_patch, (diff_weight,), strength_model, offset, function))
+                    else:
+                        current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                    current_hook_patches[key] = current_patches
+            self.hook_patches[hook.hook_ref] = current_hook_patches
+            # since should care about these patches too to determine if same model, reroll patches_uuid
+            self.patches_uuid = uuid.uuid4()
+            return list(p)
 
     def get_weight_diffs(self, patches):
-        comfy.model_management.unload_model_clones(self)
-        weights: Dict[str, Tuple] = {}
-        p = set()
-        model_sd = self.model.state_dict()
-        for k in patches:
-            if k in model_sd:
-                p.add(k)
-                model_dtype = comfy.utils.get_attr(self.model, k).dtype
-                if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
-                    diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, k).to(torch.float32)).to(model_dtype)
-                else:
-                    diff_weight = patches[k]-comfy.utils.get_attr(self.model, k)
-                weights[k] = (diff_weight,)
-        return weights, p
+        with self.use_ejected():
+            comfy.model_management.unload_model_clones(self)
+            weights: Dict[str, Tuple] = {}
+            p = set()
+            model_sd = self.model.state_dict()
+            for k in patches:
+                if k in model_sd:
+                    p.add(k)
+                    model_dtype = comfy.utils.get_attr(self.model, k).dtype
+                    if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                        diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, k).to(torch.float32)).to(model_dtype)
+                    else:
+                        diff_weight = patches[k]-comfy.utils.get_attr(self.model, k)
+                    weights[k] = (diff_weight,)
+            return weights, p
 
     def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
         # combined_patches will contain weights of all relevant hooks, per key
@@ -765,27 +855,28 @@ class ModelPatcher:
             callback(self, hooks)
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
-        self.unpatch_hooks()
-        model_sd = self.model_state_dict()
-        # if have cached weights for hooks, use it
-        cached_weights = self.cached_hook_patches.get(hooks, None)
-        if cached_weights is not None:
-            for key in cached_weights:
-                if key not in model_sd:
-                    print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
-                    continue
-                self.patch_cached_hook_weights(cached_weights=cached_weights, key=key)
-        else:
-            relevant_patches = self.get_combined_hook_patches(hooks=hooks)
-            original_weights = None
-            if len(relevant_patches) > 0:
-                original_weights = self.get_key_patches()
-            for key in relevant_patches:
-                if key not in model_sd:
-                    print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
-                    continue
-                self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights)
-        self.current_hooks = hooks
+        with self.use_ejected():
+            self.unpatch_hooks()
+            model_sd = self.model_state_dict()
+            # if have cached weights for hooks, use it
+            cached_weights = self.cached_hook_patches.get(hooks, None)
+            if cached_weights is not None:
+                for key in cached_weights:
+                    if key not in model_sd:
+                        print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
+                        continue
+                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key)
+            else:
+                relevant_patches = self.get_combined_hook_patches(hooks=hooks)
+                original_weights = None
+                if len(relevant_patches) > 0:
+                    original_weights = self.get_key_patches()
+                for key in relevant_patches:
+                    if key not in model_sd:
+                        print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
+                        continue
+                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights)
+            self.current_hooks = hooks
 
     def patch_cached_hook_weights(self, cached_weights: Dict, key: str):
         if key not in self.hook_backup:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
|
|
||||||
def unpatch_hooks(self) -> None:
|
def unpatch_hooks(self) -> None:
|
||||||
if len(self.hook_backup) == 0:
|
with self.use_ejected():
|
||||||
self.current_hooks = None
|
if len(self.hook_backup) == 0:
|
||||||
return
|
self.current_hooks = None
|
||||||
keys = list(self.hook_backup.keys())
|
return
|
||||||
if self.weight_inplace_update:
|
keys = list(self.hook_backup.keys())
|
||||||
for k in keys:
|
if self.weight_inplace_update:
|
||||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches
|
for k in keys:
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0])
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches
|
||||||
else:
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0])
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
else:
|
||||||
else:
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
for k in keys:
|
else:
|
||||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
for k in keys:
|
||||||
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0])
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
else:
|
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0])
|
||||||
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
self.hook_backup.clear()
|
self.hook_backup.clear()
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
|
|
||||||
def clean_hooks(self):
|
def clean_hooks(self):
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
|
|
|
@@ -2,6 +2,11 @@ import torch
 import comfy.model_management
 import comfy.conds
 import comfy.hooks
+from typing import TYPE_CHECKING, Dict, List
+if TYPE_CHECKING:
+    from comfy.model_patcher import ModelPatcher
+    from comfy.model_base import BaseModel
+    from comfy.controlnet import ControlBase
 
 def prepare_mask(noise_mask, shape, device):
     """ensures noise mask is of proper dimensions"""
@@ -15,9 +20,22 @@ def get_models_from_cond(cond, model_type):
     models = []
     for c in cond:
         if model_type in c:
-            models += [c[model_type]]
+            if isinstance(c[model_type], list):
+                models += c[model_type]
+            else:
+                models += [c[model_type]]
     return models
 
+def get_hooks_from_cond(cond, filter_types: List[comfy.hooks.EnumHookType]=None):
+    hooks: Dict[comfy.hooks.Hook, None] = {}
+    for c in cond:
+        if 'hooks' in c:
+            for hook in c['hooks'].hooks:
+                hook: comfy.hooks.Hook
+                if not filter_types or hook.hook_type in filter_types:
+                    hooks[hook] = None
+    return hooks
+
 def convert_cond(cond):
     out = []
     for c in cond:
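get_hooks_from_cond returns a Dict[Hook, None] used as an ordered set, so callers can merge hooks from several conds with dict.update() without duplicates, optionally keeping only certain hook types. A hedged usage sketch (positive_conds is an assumed list of cond dicts that may carry a 'hooks' HookGroup):

all_hooks = {}
all_hooks.update(get_hooks_from_cond(positive_conds))

# keep only AddModel hooks, as get_additional_models does below
addmodel_hooks = get_hooks_from_cond(positive_conds, filter_types=[comfy.hooks.EnumHookType.AddModel])
for hook in addmodel_hooks:  # dict iteration preserves insertion order
    print(hook.hook_type)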
@@ -32,12 +50,16 @@ def convert_cond(cond):
 
 def get_additional_models(conds, dtype):
     """loads additional models in conditioning"""
-    cnets = []
+    cnets: List[ControlBase] = []
     gligen = []
+    add_models = []
+    hooks: Dict[comfy.hooks.AddModelHook, None] = {}
+
     for k in conds:
         cnets += get_models_from_cond(conds[k], "control")
         gligen += get_models_from_cond(conds[k], "gligen")
+        add_models += get_models_from_cond(conds[k], "additional_models")
+        hooks.update(get_hooks_from_cond(conds[k], [comfy.hooks.EnumHookType.AddModel]))
 
     control_nets = set(cnets)
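Each AddModel hook gathered here wraps a ModelPatcher, which the next hunk unwraps into the load list via hook.model. A hedged sketch of constructing one (extra_patcher is an assumed ModelPatcher; the clone assertion relies on AddModelHook.clone copying the model reference, as shown in the hooks hunk above):

hook = comfy.hooks.AddModelHook(extra_patcher)
assert hook.hook_type == comfy.hooks.EnumHookType.AddModel
assert hook.clone().model is extra_patcher  # clone shares the wrapped patcher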
@@ -48,7 +70,9 @@ def get_additional_models(conds, dtype):
         inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = [x[1] for x in gligen]
-    models = control_models + gligen
+    hook_models = [x.model for x in hooks]
+    models = control_models + gligen + add_models + hook_models
+
     return models, inference_memory
 
 def cleanup_additional_models(models):
@@ -58,10 +82,11 @@ def cleanup_additional_models(models):
             m.cleanup()
 
 
-def prepare_sampling(model, noise_shape, conds):
+def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
     device = model.load_device
-    real_model = None
+    real_model: 'BaseModel' = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
+    models += model.get_all_additional_models() # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
     comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
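get_all_additional_models() flattens the patcher's keyed additional_models dict into one list, which prepare_sampling appends to the cond-derived models before load_models_gpu sees them. A hedged sketch of the flattening (model, cn_patcher, extra_patcher are assumed instances):

model.set_additional_models("controlnet", [cn_patcher])
model.set_additional_models("extras", [extra_patcher])
# flattened in key-insertion order; entries are the original patchers, not clones
assert model.get_all_additional_models() == [cn_patcher, extra_patcher]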
@@ -79,12 +104,9 @@ def cleanup_models(conds, models):
 
     cleanup_additional_models(set(control_cleanup))
 
-def prepare_model_patcher(model, conds):
+def prepare_model_patcher(model: 'ModelPatcher', conds):
     # check for hooks in conds - if not registered, see if can be applied
     hooks = {}
     for k in conds:
-        for cond in conds[k]:
-            if 'hooks' in cond:
-                for hook in cond['hooks'].hooks:
-                    hooks[hook] = None
+        hooks.update(get_hooks_from_cond(conds[k]))
     model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model)