Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks

kosinkadink1@gmail.com 2024-09-19 21:43:58 +09:00
parent e80dc96627
commit 55014293b1
3 changed files with 394 additions and 267 deletions

comfy/hooks.py (View File)

@@ -19,6 +19,7 @@ class EnumHookMode(enum.Enum):
 class EnumHookType(enum.Enum):
     Weight = "weight"
     Patch = "patch"
+    AddModel = "addmodel"
 
 class EnumWeightTarget(enum.Enum):
     Model = "model"
@@ -121,10 +122,22 @@ class PatchHook(Hook):
     def clone(self, subtype: Callable=None):
         if subtype is None:
             subtype = type(self)
-        c: PatchHook = super().clone(type(self))
+        c: PatchHook = super().clone(subtype)
         c.patches = self.patches
         return c
 
+class AddModelHook(Hook):
+    def __init__(self, model: 'ModelPatcher'):
+        super().__init__(hook_type=EnumHookType.AddModel)
+        self.model = model
+
+    def clone(self, subtype: Callable=None):
+        if subtype is None:
+            subtype = type(self)
+        c: AddModelHook = super().clone(subtype)
+        c.model = self.model
+        return c
+
 class HookGroup:
     def __init__(self):
         self.hooks: List[Hook] = []
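The new AddModelHook simply carries a ModelPatcher so it can be discovered in conditioning at sampling time (see the sampler_helpers changes below). A minimal sketch of how such a hook might be attached to a cond list, assuming the ComfyUI environment of this diff; attach_model_to_cond and extra_patcher are illustrative names, and each cond entry is treated as a dict with an optional 'hooks' HookGroup, mirroring how get_hooks_from_cond reads it:

import comfy.hooks

def attach_model_to_cond(cond: list, extra_patcher) -> list:
    # Wrap the extra ModelPatcher in the AddModelHook added by this commit.
    hook = comfy.hooks.AddModelHook(model=extra_patcher)
    for c in cond:
        # Reuse an existing HookGroup on the cond entry, or create one.
        group = c.get('hooks', comfy.hooks.HookGroup())
        group.hooks.append(hook)
        c['hooks'] = group
    return cond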

comfy/model_patcher.py (View File)

@@ -108,6 +108,8 @@ class CallbacksMP:
     ON_PREPARE_STATE = "on_prepare_state"
     ON_APPLY_HOOKS = "on_apply_hooks"
     ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
+    ON_INJECT_MODEL = "on_inject_model"
+    ON_EJECT_MODEL = "on_eject_model"
 
     @classmethod
     def init_callbacks(cls):
@@ -119,8 +121,37 @@ class CallbacksMP:
             cls.ON_PREPARE_STATE: [],
             cls.ON_APPLY_HOOKS: [],
             cls.ON_REGISTER_ALL_HOOK_PATCHES: [],
+            cls.ON_INJECT_MODEL: [],
+            cls.ON_EJECT_MODEL: [],
         }
 
+class AutoPatcherEjector:
+    def __init__(self, model: 'ModelPatcher', skip_until_exit=False):
+        self.model = model
+        self.was_injected = False
+        self.prev_skip_injection = False
+        self.skip_until_exit = skip_until_exit
+
+    def __enter__(self):
+        self.was_injected = False
+        self.prev_skip_injection = self.model.skip_injection
+        if self.skip_until_exit:
+            self.model.skip_injection = True
+        if self.model.is_injected:
+            self.model.eject_model()
+            self.was_injected = True
+
+    def __exit__(self, *args):
+        if self.was_injected:
+            if self.skip_until_exit or not self.model.skip_injection:
+                self.model.inject_model()
+        self.model.skip_injection = self.prev_skip_injection
+
+class PatcherInjection:
+    def __init__(self, inject: Callable, eject: Callable):
+        self.inject = inject
+        self.eject = eject
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
@@ -143,9 +174,13 @@ class ModelPatcher:
         self.patches_uuid = uuid.uuid4()
         self.attachments: Dict[str] = {}
-        self.additional_models: list[ModelPatcher] = []
+        self.additional_models: Dict[str, List[ModelPatcher]] = {}
         self.callbacks: Dict[str, List[Callable]] = CallbacksMP.init_callbacks()
 
+        self.is_injected = False
+        self.skip_injection = False
+        self.injections: Dict[str, List[PatcherInjection]] = {}
+
         self.hook_patches: Dict[comfy.hooks._HookRef] = {}
         self.hook_patches_backup: Dict[comfy.hooks._HookRef] = {}
         self.hook_backup: Dict[str, Tuple[torch.Tensor, torch.device]] = {}
@@ -196,11 +231,16 @@ class ModelPatcher:
             else:
                 n.attachments[k] = self.attachments[k]
         # additional models
-        for m in self.additional_models:
-            n.additional_models.append(m.clone())
+        for k, c in self.additional_models.items():
+            n.additional_models[k] = [x.clone() for x in c]
         # callbacks
         for k, c in self.callbacks.items():
             n.callbacks[k] = c.copy()
+        # injection
+        n.is_injected = self.is_injected
+        n.skip_injection = self.skip_injection
+        for k, i in self.injections.items():
+            n.injections[k] = i.copy()
         # hooks
         n.hook_patches = create_hook_patches_clone(self.hook_patches)
         n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
@@ -342,6 +382,7 @@ class ModelPatcher:
             return self.model.get_dtype()
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        with self.use_ejected():
             p = set()
             model_sd = self.model.state_dict()
             for k in patches:
@@ -386,6 +427,7 @@ class ModelPatcher:
             return p
 
     def model_state_dict(self, filter_prefix=None):
+        with self.use_ejected():
             sd = self.model.state_dict()
             keys = list(sd.keys())
             if filter_prefix is not None:
@@ -417,6 +459,7 @@ class ModelPatcher:
             comfy.utils.set_attr_param(self.model, key, out_weight)
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        with self.use_ejected():
             self.unpatch_hooks()
             mem_counter = 0
             patch_counter = 0
@@ -501,9 +544,14 @@ class ModelPatcher:
             self.model.lowvram_patch_counter += patch_counter
             self.model.device = device_to
             self.model.model_loaded_weight_memory = mem_counter
+
+            for callback in self.callbacks[CallbacksMP.ON_LOAD]:
+                callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
+
         self.apply_hooks(self.forced_hooks)
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+        with self.use_ejected():
             for k in self.object_patches:
                 old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
                 if k not in self.object_patches_backup:
@@ -516,9 +564,11 @@ class ModelPatcher:
             if load_weights:
                 self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+        self.inject_model()
         return self.model
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
+        self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
             if self.model.model_lowvram:
@@ -555,6 +605,7 @@ class ModelPatcher:
         self.object_patches_backup.clear()
 
     def partially_unload(self, device_to, memory_to_free=0):
+        with self.use_ejected():
             memory_freed = 0
             patch_counter = 0
             unload_list = []
@@ -605,6 +656,7 @@ class ModelPatcher:
             return memory_freed
 
     def partially_load(self, device_to, extra_memory=0):
+        with self.use_ejected(skip_injection=True):
             self.unpatch_model(unpatch_weights=False)
             self.patch_model(load_weights=False)
             full_load = False
@@ -629,13 +681,49 @@ class ModelPatcher:
         for callback in self.callbacks[CallbacksMP.ON_CLEANUP]:
             callback(self)
 
-    def add_callback(self, key, callback: Callable):
+    def get_all_additional_models(self):
+        all_models = []
+        for models in self.additional_models.values():
+            all_models.extend(models)
+        return all_models
+
+    def add_callback(self, key: str, callback: Callable):
         if key not in self.callbacks:
             raise Exception(f"Callback '{key}' is not recognized.")
         self.callbacks[key].append(callback)
 
-    def add_attachment(self, attachment):
-        self.attachments.append(attachment)
+    def set_attachments(self, key: str, attachment):
+        self.attachments[key] = attachment
+
+    def set_injections(self, key: str, injections: List[PatcherInjection]):
+        self.injections[key] = injections
+
+    def set_additional_models(self, key: str, models: List['ModelPatcher']):
+        self.additional_models[key] = models
+
+    def use_ejected(self, skip_injection=False):
+        return AutoPatcherEjector(self, skip_until_exit=skip_injection)
+
+    def inject_model(self):
+        if self.is_injected or self.skip_injection:
+            return
+        for injections in self.injections.values():
+            for inj in injections:
+                inj.inject(self)
+                self.is_injected = True
+        if self.is_injected:
+            for callback in self.callbacks[CallbacksMP.ON_INJECT_MODEL]:
+                callback(self)
+
+    def eject_model(self):
+        if not self.is_injected:
+            return
+        for injections in self.injections.values():
+            for inj in injections:
+                inj.eject(self)
+        self.is_injected = False
+        for callback in self.callbacks[CallbacksMP.ON_EJECT_MODEL]:
+            callback(self)
 
     def pre_run(self):
         for callback in self.callbacks[CallbacksMP.ON_PRE_RUN]:
@@ -685,6 +773,7 @@ class ModelPatcher:
             callback(self, hooks_dict, target)
 
     def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0, is_diff=False):
+        with self.use_ejected():
             # NOTE: this mirrors behavior of add_patches func
            if is_diff:
                comfy.model_management.unload_model_clones(self)
@@ -723,6 +812,7 @@ class ModelPatcher:
             return list(p)
 
     def get_weight_diffs(self, patches):
+        with self.use_ejected():
            comfy.model_management.unload_model_clones(self)
            weights: Dict[str, Tuple] = {}
            p = set()
@@ -765,6 +855,7 @@ class ModelPatcher:
             callback(self, hooks)
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
+        with self.use_ejected():
            self.unpatch_hooks()
            model_sd = self.model_state_dict()
            # if have cached weights for hooks, use it
@@ -825,6 +916,7 @@ class ModelPatcher:
             comfy.utils.set_attr_param(self.model, key, out_weight)
 
     def unpatch_hooks(self) -> None:
+        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
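Taken together, PatcherInjection, set_injections, and the AutoPatcherEjector returned by use_ejected() let extensions temporarily graft modules into the model: patch_model() calls inject_model() once weights are patched, unpatch_model() ejects first, and weight-touching methods wrap themselves in use_ejected() so injections are never active while state dicts are read or written. A rough usage sketch assuming the ComfyUI environment of this diff; install_custom_attention, remove_custom_attention, and with_custom_attention are hypothetical helpers, not part of this commit:

from comfy.model_patcher import ModelPatcher, PatcherInjection

def install_custom_attention(patcher: ModelPatcher):
    # Hypothetical: swap custom modules into patcher.model here.
    ...

def remove_custom_attention(patcher: ModelPatcher):
    # Hypothetical: restore the original modules here.
    ...

def with_custom_attention(model: ModelPatcher) -> ModelPatcher:
    m = model.clone()
    # Register an inject/eject pair under a key; inject_model()/eject_model()
    # run these around patching, and clone() copies them to derived patchers.
    m.set_injections("custom_attention", [
        PatcherInjection(inject=install_custom_attention, eject=remove_custom_attention),
    ])
    return m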

comfy/sampler_helpers.py (View File)

@@ -2,6 +2,11 @@ import torch
 import comfy.model_management
 import comfy.conds
 import comfy.hooks
+from typing import TYPE_CHECKING, Dict, List
+if TYPE_CHECKING:
+    from comfy.model_patcher import ModelPatcher
+    from comfy.model_base import BaseModel
+    from comfy.controlnet import ControlBase
 
 def prepare_mask(noise_mask, shape, device):
     """ensures noise mask is of proper dimensions"""
@@ -15,9 +20,22 @@ def get_models_from_cond(cond, model_type):
     models = []
     for c in cond:
         if model_type in c:
+            if isinstance(c[model_type], list):
+                models += c[model_type]
+            else:
                 models += [c[model_type]]
     return models
 
+def get_hooks_from_cond(cond, filter_types: List[comfy.hooks.EnumHookType]=None):
+    hooks: Dict[comfy.hooks.Hook, None] = {}
+    for c in cond:
+        if 'hooks' in c:
+            for hook in c['hooks'].hooks:
+                hook: comfy.hooks.Hook
+                if not filter_types or hook.hook_type in filter_types:
+                    hooks[hook] = None
+    return hooks
+
 def convert_cond(cond):
     out = []
     for c in cond:
@@ -32,12 +50,16 @@ def convert_cond(cond):
 def get_additional_models(conds, dtype):
     """loads additional models in conditioning"""
-    cnets = []
+    cnets: List[ControlBase] = []
     gligen = []
+    add_models = []
+    hooks: Dict[comfy.hooks.AddModelHook, None] = {}
     for k in conds:
         cnets += get_models_from_cond(conds[k], "control")
         gligen += get_models_from_cond(conds[k], "gligen")
+        add_models += get_models_from_cond(conds[k], "additional_models")
+        hooks.update(get_hooks_from_cond(conds[k], [comfy.hooks.EnumHookType.AddModel]))
 
     control_nets = set(cnets)
@@ -48,7 +70,9 @@ def get_additional_models(conds, dtype):
         inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = [x[1] for x in gligen]
-    models = control_models + gligen
+    hook_models = [x.model for x in hooks]
+    models = control_models + gligen + add_models + hook_models
+
     return models, inference_memory
 
 def cleanup_additional_models(models):
@@ -58,10 +82,11 @@ def cleanup_additional_models(models):
             m.cleanup()
 
-def prepare_sampling(model, noise_shape, conds):
+def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
     device = model.load_device
-    real_model = None
+    real_model: 'BaseModel' = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
+    models += model.get_all_additional_models() # TODO: does this require inference_memory update?
     memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
     minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
     comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
cleanup_additional_models(set(control_cleanup)) cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model, conds): def prepare_model_patcher(model: 'ModelPatcher', conds):
# check for hooks in conds - if not registered, see if can be applied # check for hooks in conds - if not registered, see if can be applied
hooks = {} hooks = {}
for k in conds: for k in conds:
for cond in conds[k]: hooks.update(get_hooks_from_cond(conds[k]))
if 'hooks' in cond:
for hook in cond['hooks'].hooks:
hooks[hook] = None
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model) model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model)
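With these changes, auxiliary models can reach prepare_sampling()'s memory planning in two ways: keyed on the ModelPatcher itself via set_additional_models() (flattened by get_all_additional_models()), or carried in conditioning, either under an "additional_models" key or via an AddModel hook, both of which get_additional_models() now collects. A hedged sketch of both paths, assuming a ComfyUI environment; register_auxiliary_models and aux_models are illustrative names, not part of this commit:

from typing import List

def register_auxiliary_models(model, cond: list, aux_models: List) -> None:
    # Path 1: key the models on the ModelPatcher; prepare_sampling() above
    # folds them in via model.get_all_additional_models().
    model.set_additional_models("auxiliary", aux_models)

    # Path 2: carry them in conditioning; get_models_from_cond(c, "additional_models")
    # now accepts either a single model or a list under that key.
    for c in cond:
        c["additional_models"] = list(aux_models)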