Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired
This commit is contained in:
parent
638c4086a3
commit
b12cc83c5b
|
@ -770,9 +770,7 @@ class ModelPatcher:
|
|||
self.add_callback_with_key(call_type, None, callback)
|
||||
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
||||
if call_type not in self.callbacks:
|
||||
raise Exception(f"Callback '{call_type}' is not recognized.")
|
||||
c = self.callbacks[call_type].setdefault(key, [])
|
||||
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str):
|
||||
|
@ -793,9 +791,7 @@ class ModelPatcher:
|
|||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||
if wrapper_type not in self.wrappers:
|
||||
raise Exception(f"Wrapper '{wrapper_type}' is not recognized.")
|
||||
w = self.wrappers[wrapper_type].setdefault(key, [])
|
||||
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
||||
|
|
|
@ -12,19 +12,11 @@ class CallbacksMP:
|
|||
ON_INJECT_MODEL = "on_inject_model"
|
||||
ON_EJECT_MODEL = "on_eject_model"
|
||||
|
||||
# callbacks dict is in the format:
|
||||
# {"call_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_callbacks(cls):
|
||||
return {
|
||||
cls.ON_CLONE: {None: []},
|
||||
cls.ON_LOAD: {None: []},
|
||||
cls.ON_CLEANUP: {None: []},
|
||||
cls.ON_PRE_RUN: {None: []},
|
||||
cls.ON_PREPARE_STATE: {None: []},
|
||||
cls.ON_APPLY_HOOKS: {None: []},
|
||||
cls.ON_REGISTER_ALL_HOOK_PATCHES: {None: []},
|
||||
cls.ON_INJECT_MODEL: {None: []},
|
||||
cls.ON_EJECT_MODEL: {None: []},
|
||||
}
|
||||
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
||||
|
@ -33,9 +25,7 @@ def add_callback_with_key(call_type: str, key: str, callback: Callable, transfor
|
|||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
callbacks: dict[str, dict[str, list]] = transformer_options.get("callbacks", {})
|
||||
if call_type not in callbacks:
|
||||
raise Exception(f"Callback '{call_type}' is not recognized.")
|
||||
c = callbacks[call_type].setdefault(key, [])
|
||||
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||
c.append(callback)
|
||||
|
||||
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
|
||||
|
@ -54,15 +44,11 @@ class WrappersMP:
|
|||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
# wrappers dict is in the format:
|
||||
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
|
||||
@classmethod
|
||||
def init_wrappers(cls):
|
||||
return {
|
||||
cls.OUTER_SAMPLE: {None: []},
|
||||
cls.SAMPLER_SAMPLE: {None: []},
|
||||
cls.CALC_COND_BATCH: {None: []},
|
||||
cls.APPLY_MODEL: {None: []},
|
||||
cls.DIFFUSION_MODEL: {None: []},
|
||||
}
|
||||
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
|
||||
return {}
|
||||
|
||||
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
|
||||
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
||||
|
@ -71,9 +57,7 @@ def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transfo
|
|||
if is_model_options:
|
||||
transformer_options = transformer_options.get("transformer_options", {})
|
||||
wrappers: dict[str, dict[str, list]] = transformer_options.get("wrappers", {})
|
||||
if wrapper_type not in wrappers:
|
||||
raise Exception(f"Wrapper '{wrapper_type}' is not recognized.")
|
||||
w = wrappers[wrapper_type].setdefault(key, [])
|
||||
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
|
||||
|
|
Loading…
Reference in New Issue