2024-11-03 03:21:16 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
class CallbacksMP:
    """String identifiers for callback types.

    Each constant is a ``call_type`` value: the outer key of a callbacks dict.
    """

    ON_CLONE = "on_clone"
    ON_LOAD = "on_load_after"
    ON_DETACH = "on_detach_after"
    ON_CLEANUP = "on_cleanup"
    ON_PRE_RUN = "on_pre_run"
    ON_PREPARE_STATE = "on_prepare_state"
    ON_APPLY_HOOKS = "on_apply_hooks"
    ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
    ON_INJECT_MODEL = "on_inject_model"
    ON_EJECT_MODEL = "on_eject_model"

    @classmethod
    def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a brand-new, empty callbacks dict.

        Once populated it has the layout
        ``{"call_type": {"key": [Callable1, Callable2, ...]}}``.
        """
        return dict()
|
2024-11-03 03:21:16 +00:00
|
|
|
|
2024-11-03 12:22:48 +00:00
|
|
|
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Register *callback* for *call_type* under the ``None`` key.

    Convenience form of ``add_callback_with_key`` for callers that do not
    need a named key; arguments are forwarded unchanged.
    """
    add_callback_with_key(
        call_type=call_type,
        key=None,
        callback=callback,
        transformer_options=transformer_options,
        is_model_options=is_model_options,
    )
|
|
|
|
|
|
|
|
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Append *callback* to the list stored at ``callbacks[call_type][key]``.

    Args:
        call_type: Outer key of the callbacks dict.
        key: Inner key grouping callbacks of the same type (may be ``None``).
        callback: The callable to register.
        transformer_options: Dict that owns (or will own) the ``"callbacks"``
            entry; it is mutated in place. Missing levels are created.
        is_model_options: When True, *transformer_options* is treated as an
            outer options dict and its ``"transformer_options"`` sub-dict
            (created if absent) is used instead.
    """
    target = (
        transformer_options.setdefault("transformer_options", {})
        if is_model_options
        else transformer_options
    )
    by_type: dict[str, dict[str, list]] = target.setdefault("callbacks", {})
    by_key = by_type.setdefault(call_type, {})
    by_key.setdefault(key, []).append(callback)
|
|
|
|
|
2024-11-14 14:06:02 +00:00
|
|
|
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False) -> list[Callable]:
    """Return the callbacks registered for *call_type* under *key*.

    Args:
        call_type: Outer key of the callbacks dict.
        key: Inner key the callbacks were registered under (may be ``None``).
        transformer_options: Dict that may hold a ``"callbacks"`` entry laid
            out as ``{call_type: {key: [callable, ...]}}``. Never mutated.
        is_model_options: When True, look inside the
            ``"transformer_options"`` sub-dict of *transformer_options*.

    Returns:
        A new list (mutating it does not affect the registry); empty when
        nothing is registered at any level.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    callbacks: dict[str, dict[str, list[Callable]]] = transformer_options.get("callbacks", {})
    # list(...) copies so callers cannot mutate the registered list via the result.
    return list(callbacks.get(call_type, {}).get(key, []))
|
|
|
|
|
2024-11-03 12:22:48 +00:00
|
|
|
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False) -> list[Callable]:
    """Return every callback registered for *call_type*, across all keys.

    Args:
        call_type: Outer key of the callbacks dict.
        transformer_options: Dict that may hold a ``"callbacks"`` entry laid
            out as ``{call_type: {key: [callable, ...]}}``. Never mutated.
        is_model_options: When True, look inside the
            ``"transformer_options"`` sub-dict of *transformer_options*.

    Returns:
        A new flat list: each key's callbacks in registration order, keys in
        the dict's iteration (insertion) order. Empty when none are found.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    callbacks: dict[str, dict[str, list[Callable]]] = transformer_options.get("callbacks", {})
    return [cb for per_key in callbacks.get(call_type, {}).values() for cb in per_key]
|
|
|
|
|
2024-11-03 03:21:16 +00:00
|
|
|
class WrappersMP:
    """String identifiers for wrapper types.

    Each constant is a ``wrapper_type`` value: the outer key of a wrappers dict.
    """

    OUTER_SAMPLE = "outer_sample"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"

    @classmethod
    def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a brand-new, empty wrappers dict.

        Once populated it has the layout
        ``{"wrapper_type": {"key": [Callable1, Callable2, ...]}}``.
        """
        return dict()
|
2024-11-03 03:21:16 +00:00
|
|
|
|
2024-11-03 12:22:48 +00:00
|
|
|
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Register *wrapper* for *wrapper_type* under the ``None`` key.

    Convenience form of ``add_wrapper_with_key`` for callers that do not
    need a named key; arguments are forwarded unchanged.
    """
    add_wrapper_with_key(
        wrapper_type=wrapper_type,
        key=None,
        wrapper=wrapper,
        transformer_options=transformer_options,
        is_model_options=is_model_options,
    )
|
|
|
|
|
|
|
|
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Append *wrapper* to the list stored at ``wrappers[wrapper_type][key]``.

    Args:
        wrapper_type: Outer key of the wrappers dict.
        key: Inner key grouping wrappers of the same type (may be ``None``).
        wrapper: The callable to register.
        transformer_options: Dict that owns (or will own) the ``"wrappers"``
            entry; it is mutated in place. Missing levels are created.
        is_model_options: When True, *transformer_options* is treated as an
            outer options dict and its ``"transformer_options"`` sub-dict
            (created if absent) is used instead.
    """
    target = (
        transformer_options.setdefault("transformer_options", {})
        if is_model_options
        else transformer_options
    )
    by_type: dict[str, dict[str, list]] = target.setdefault("wrappers", {})
    by_key = by_type.setdefault(wrapper_type, {})
    by_key.setdefault(key, []).append(wrapper)
|
|
|
|
|
2024-11-14 14:06:02 +00:00
|
|
|
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False) -> list[Callable]:
    """Return the wrappers registered for *wrapper_type* under *key*.

    Args:
        wrapper_type: Outer key of the wrappers dict.
        key: Inner key the wrappers were registered under (may be ``None``).
        transformer_options: Dict that may hold a ``"wrappers"`` entry laid
            out as ``{wrapper_type: {key: [callable, ...]}}``. Never mutated.
        is_model_options: When True, look inside the
            ``"transformer_options"`` sub-dict of *transformer_options*.

    Returns:
        A new list (mutating it does not affect the registry); empty when
        nothing is registered at any level.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    wrappers: dict[str, dict[str, list[Callable]]] = transformer_options.get("wrappers", {})
    # list(...) copies so callers cannot mutate the registered list via the result.
    return list(wrappers.get(wrapper_type, {}).get(key, []))
|
|
|
|
|
2024-11-03 12:22:48 +00:00
|
|
|
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False) -> list[Callable]:
    """Return every wrapper registered for *wrapper_type*, across all keys.

    Args:
        wrapper_type: Outer key of the wrappers dict.
        transformer_options: Dict that may hold a ``"wrappers"`` entry laid
            out as ``{wrapper_type: {key: [callable, ...]}}``. Never mutated.
        is_model_options: When True, look inside the
            ``"transformer_options"`` sub-dict of *transformer_options*.

    Returns:
        A new flat list: each key's wrappers in registration order, keys in
        the dict's iteration (insertion) order. Empty when none are found.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    wrappers: dict[str, dict[str, list[Callable]]] = transformer_options.get("wrappers", {})
    return [w for per_key in wrappers.get(wrapper_type, {}).values() for w in per_key]
|
|
|
|
|
2024-11-03 03:21:16 +00:00
|
|
|
class WrapperExecutor:
    """Handles call stack of wrappers around a function in an ordered manner.

    Each wrapper is called as ``wrapper(executor, *args, **kwargs)``. Calling
    the ``executor`` it receives invokes the next wrapper in the list, or the
    original function once every wrapper has been entered. Wrappers therefore
    run outermost-first in list order.
    """

    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
        # NOTE: class_obj exists so that wrappers surrounding a class method can access
        # the class instance at runtime via executor.class_obj
        self.original = original
        self.class_obj = class_obj
        # Copy so later mutation of the caller's list does not alter this chain.
        self.wrappers = wrappers.copy()
        self.idx = idx
        # Once idx has moved past the last wrapper, execute() calls the original.
        self.is_last = idx == len(wrappers)

    def __call__(self, *args, **kwargs):
        """Calls the next wrapper or original function, whichever is appropriate."""
        new_executor = self._create_next_executor()
        return new_executor.execute(*args, **kwargs)

    def execute(self, *args, **kwargs):
        """Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
        if self.is_last:
            return self.original(*args, **kwargs)
        return self.wrappers[self.idx](self, *args, **kwargs)

    def _create_next_executor(self) -> 'WrapperExecutor':
        """Build the executor one position further along the wrapper chain.

        Raises:
            Exception: if this executor is already past the last wrapper
                (i.e. a wrapper called its executor after the original ran).
        """
        new_idx = self.idx + 1
        if new_idx > len(self.wrappers):
            raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
        if self.class_obj is None:
            return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
        return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)

    @classmethod
    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
        """Create an executor for a plain function (``class_obj`` is None)."""
        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)

    @classmethod
    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
        """Create an executor whose wrappers can reach *class_obj* via ``executor.class_obj``."""
        return cls(original, class_obj, wrappers, idx=idx)
|
|
|
|
|
|
|
|
class PatcherInjection:
    """Holds an ``inject`` / ``eject`` callable pair.

    The callables are stored as given and are not invoked here. The names
    suggest ``eject`` undoes ``inject``, but this class does not enforce that.
    """

    def __init__(self, inject: Callable, eject: Callable):
        self.inject, self.eject = inject, eject
|
2024-11-03 12:22:48 +00:00
|
|
|
|
|
|
|
def copy_nested_dicts(input_dict: dict):
    """Copy *input_dict*, recursing into nested dicts and shallow-copying lists.

    Nested dicts are copied recursively. List values get a new list object,
    but the items inside it are shared. Every other value is shared with the
    input as-is.
    """
    result = input_dict.copy()
    for k, v in input_dict.items():
        if isinstance(v, dict):
            v = copy_nested_dicts(v)
        elif isinstance(v, list):
            v = v.copy()
        else:
            # Non-container values keep the reference already copied above.
            continue
        result[k] = v
    return result
|
|
|
|
|
2024-11-11 14:41:08 +00:00
|
|
|
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    """Merge *dict2* into *dict1* recursively and return the result.

    The same rules apply at every nesting level:
      * dict values are merged recursively (dict1 is the base, dict2 the overlay);
      * list values from *dict2* are appended after those already in *dict1*;
      * any other value from *dict2* replaces the one in *dict1*.
    If *dict1* holds a different kind of value than *dict2* at the same key
    (e.g. a scalar where *dict2* has a dict or list), *dict2*'s value wins.

    Args:
        dict1: Base dict.
        dict2: Dict layered on top; it is never mutated, and its lists/dicts
            are not aliased into the result.
        copy_dict1: When True (default), *dict1* is left untouched and a copy
            made with ``copy_nested_dicts`` is merged into. When False, *dict1*
            and the nested dicts/lists it already contains are updated in place.

    Returns:
        The merged dict (``dict1`` itself when ``copy_dict1`` is False).
    """
    merged_dict = copy_nested_dicts(dict1) if copy_dict1 else dict1
    for key, value in dict2.items():
        existing = merged_dict.get(key)
        if isinstance(value, dict):
            if not isinstance(existing, dict):
                existing = {}
            # Keep base-then-overlay orientation at every depth (the previous
            # swapped call flipped precedence and list order per level). The base
            # is either our private copy or a dict the caller allowed us to
            # mutate, so no further copy is needed.
            merged_dict[key] = merge_nested_dicts(existing, value, copy_dict1=False)
        elif isinstance(value, list):
            if not isinstance(existing, list):
                existing = []
            existing.extend(value)
            merged_dict[key] = existing
        else:
            merged_dict[key] = value
    return merged_dict
|