Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate

This commit is contained in:
Jedrzej Kosinski 2024-11-16 16:40:45 -06:00
parent bcc6a22178
commit e177149ae4
3 changed files with 59 additions and 24 deletions

View File

@ -261,8 +261,9 @@ class HookGroup:
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
for hook in other.hooks:
c.add(hook.clone())
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
@ -577,10 +578,44 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
print(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
def set_hooks_for_conditioning(cond, hooks: HookGroup):
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values(cond, {'hooks': hooks})
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
@ -612,26 +647,26 @@ def combine_with_new_conds(conds: list, new_conds: list):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_mask_conds(conds: list, strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None):
masked_conds = []
def set_conds_props(conds: list, strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, opt_hooks)
c = set_hooks_for_conditioning(c, opt_hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=opt_mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=opt_timestep_range)
# finally, apply mask to conditioning and store
masked_conds.append(c)
return masked_conds
final_conds.append(c)
return final_conds
def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
opt_mask: torch.Tensor=None, opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None):
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
opt_mask: torch.Tensor=None, opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, opt_hooks)
masked_c = set_hooks_for_conditioning(masked_c, opt_hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=opt_mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
@ -640,12 +675,12 @@ def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_and_combine_conds(conds: list, new_conds: list,
opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None):
def set_default_conds_and_combine(conds: list, new_conds: list,
opt_hooks: HookGroup=None, opt_timestep_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, opt_hooks)
new_c = set_hooks_for_conditioning(new_c, opt_hooks, append_hooks=append_hooks)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present

View File

@ -968,7 +968,7 @@ class ModelPatcher:
combined_patches[key] = current_patches
return combined_patches
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None, force_apply=False):
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}

View File

@ -42,7 +42,7 @@ class PairConditioningSetProperties:
def set_properties(self, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_mask_conds(conds=[positive_NEW, negative_NEW],
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_positive, final_negative)
@ -76,7 +76,7 @@ class PairConditioningSetPropertiesAndCombine:
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_positive, final_negative)
@ -106,7 +106,7 @@ class ConditioningSetProperties:
def set_properties(self, cond_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_mask_conds(conds=[cond_NEW],
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_cond,)
@ -137,7 +137,7 @@ class ConditioningSetPropertiesAndCombine:
def set_properties(self, cond, cond_NEW,
strength: float, set_cond_area: str,
opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None, opt_timesteps: tuple=None):
(final_cond,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_NEW],
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
strength=strength, set_cond_area=set_cond_area,
opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
return (final_cond,)
@ -162,7 +162,7 @@ class PairConditioningCombine:
FUNCTION = "combine"
def combine(self, positive_A, negative_A, positive_B, negative_B):
final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
return (final_positive, final_negative,)
class PairConditioningSetDefaultAndCombine:
@ -189,7 +189,7 @@ class PairConditioningSetDefaultAndCombine:
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
opt_hooks: comfy.hooks.HookGroup=None):
final_positive, final_negative = comfy.hooks.set_default_and_combine_conds(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
opt_hooks=opt_hooks)
return (final_positive, final_negative)
@ -214,7 +214,7 @@ class ConditioningSetDefaultAndCombine:
def set_default_and_combine(self, cond, cond_DEFAULT,
opt_hooks: comfy.hooks.HookGroup=None):
(final_conditioning,) = comfy.hooks.set_default_and_combine_conds(conds=[cond], new_conds=[cond_DEFAULT],
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
opt_hooks=opt_hooks)
return (final_conditioning,)