Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired

This commit is contained in:
Jedrzej Kosinski 2024-11-11 09:05:07 -06:00
parent 638c4086a3
commit b12cc83c5b
2 changed files with 12 additions and 32 deletions

View File

@ -770,9 +770,7 @@ class ModelPatcher:
self.add_callback_with_key(call_type, None, callback)
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
if call_type not in self.callbacks:
raise Exception(f"Callback '{call_type}' is not recognized.")
c = self.callbacks[call_type].setdefault(key, [])
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
@ -793,9 +791,7 @@ class ModelPatcher:
self.add_wrapper_with_key(wrapper_type, None, wrapper)
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
if wrapper_type not in self.wrappers:
raise Exception(f"Wrapper '{wrapper_type}' is not recognized.")
w = self.wrappers[wrapper_type].setdefault(key, [])
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):

View File

@ -12,19 +12,11 @@ class CallbacksMP:
ON_INJECT_MODEL = "on_inject_model"
ON_EJECT_MODEL = "on_eject_model"
# callbacks dict is in the format:
# {"call_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_callbacks(cls):
return {
cls.ON_CLONE: {None: []},
cls.ON_LOAD: {None: []},
cls.ON_CLEANUP: {None: []},
cls.ON_PRE_RUN: {None: []},
cls.ON_PREPARE_STATE: {None: []},
cls.ON_APPLY_HOOKS: {None: []},
cls.ON_REGISTER_ALL_HOOK_PATCHES: {None: []},
cls.ON_INJECT_MODEL: {None: []},
cls.ON_EJECT_MODEL: {None: []},
}
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
@ -33,9 +25,7 @@ def add_callback_with_key(call_type: str, key: str, callback: Callable, transfor
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
callbacks: dict[str, dict[str, list]] = transformer_options.get("callbacks", {})
if call_type not in callbacks:
raise Exception(f"Callback '{call_type}' is not recognized.")
c = callbacks[call_type].setdefault(key, [])
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
@ -54,15 +44,11 @@ class WrappersMP:
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
# wrappers dict is in the format:
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_wrappers(cls):
return {
cls.OUTER_SAMPLE: {None: []},
cls.SAMPLER_SAMPLE: {None: []},
cls.CALC_COND_BATCH: {None: []},
cls.APPLY_MODEL: {None: []},
cls.DIFFUSION_MODEL: {None: []},
}
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
@ -71,9 +57,7 @@ def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transfo
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
wrappers: dict[str, dict[str, list]] = transformer_options.get("wrappers", {})
if wrapper_type not in wrappers:
raise Exception(f"Wrapper '{wrapper_type}' is not recognized.")
w = wrappers[wrapper_type].setdefault(key, [])
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):