Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options

This commit is contained in:
Jedrzej Kosinski 2024-11-11 08:41:08 -06:00
parent 9dde713347
commit 638c4086a3
7 changed files with 77 additions and 19 deletions

View File

@ -118,6 +118,14 @@ class ControlBase:
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original

View File

@ -54,7 +54,7 @@ class Hook:
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initalize_timesteps(model)
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
@ -193,7 +193,7 @@ class AddModelsHook(Hook):
def add_hook_models(self, model: 'ModelPatcher'):
pass
class AddCallbackHook(Hook):
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.AddCallback)
self.key = key
@ -202,7 +202,7 @@ class AddCallbackHook(Hook):
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddCallbackHook = super().clone(subtype)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
@ -227,7 +227,7 @@ class SetInjectionsHook(Hook):
def add_hook_injections(self, model: 'ModelPatcher'):
pass
class AddWrapperHook(Hook):
class WrapperHook(Hook):
def __init__(self, key: str=None, wrapper: Callable=None):
super().__init__(hook_type=EnumHookType.AddWrapper)
self.key = key
@ -236,7 +236,7 @@ class AddWrapperHook(Hook):
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddWrapperHook = super().clone(subtype)
c: WrapperHook = super().clone(subtype)
c.key = self.key
c.wrapper = self.wrapper
return c
@ -268,7 +268,10 @@ class HookGroup:
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
hook_kf = hook_kf.clone()
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
@ -330,7 +333,7 @@ class HookGroup:
def reset(self):
for hook in self.hooks:
hook.hook_keyframe.reset()
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
@ -414,11 +417,11 @@ class HookKeyframeGroup:
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe)
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initalize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

View File

@ -990,12 +990,14 @@ class ModelPatcher:
combined_patches[key] = current_patches
return combined_patches
def apply_hooks(self, hooks: comfy.hooks.HookGroup):
def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks:
return
return {}
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():

View File

@ -139,8 +139,11 @@ def copy_nested_dicts(input_dict: dict):
new_dict[key] = value.copy()
return new_dict
def merge_nested_dicts(dict1: dict, dict2: dict):
merged_dict = copy_nested_dicts(dict1)
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
if copy_dict1:
merged_dict = copy_nested_dicts(dict1)
else:
merged_dict = dict1
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})

View File

@ -1,9 +1,11 @@
from __future__ import annotations
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
import torch
import collections
from comfy import model_management
@ -248,8 +250,6 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
to_batch = batch_amount
break
model.current_patcher.apply_hooks(hooks=hooks)
input_x = []
mult = []
c = []
@ -275,11 +275,14 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = {}
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
@ -774,6 +777,34 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
# determine which ControlNets have extra_hooks that should be combined with normal hooks
hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
for k in conds:
for kk in conds[k]:
if 'control' in kk:
control: 'ControlBase' = kk['control']
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
to_replace = hook_replacement.setdefault((control, hooks), [])
to_replace.append(kk)
# if nothing to replace, do nothing
if len(hook_replacement) == 0:
return
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
# on the cond dicts
for key, conds_to_modify in hook_replacement.items():
control = key[0]
hooks = key[1]
hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
# if combined hooks are not None, set as new hooks for all relevant conds
if hooks is not None:
for cond in conds_to_modify:
cond['hooks'] = hooks
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
@ -842,6 +873,7 @@ class CFGGuider:
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
preprocess_conds_hooks(self.conds)
try:
orig_model_options = self.model_options

View File

@ -4,6 +4,7 @@ from enum import Enum
import logging
from comfy import model_management
from comfy.utils import ProgressBar
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
@ -119,7 +120,7 @@ class CLIP:
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]=None):
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]=None, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
@ -132,6 +133,10 @@ class CLIP:
self.load_model()
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
hooks_keyframes = scheduled_opts[1]
@ -150,6 +155,9 @@ class CLIP:
if add_dict is not None:
pooled_dict.update(add_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled

View File

@ -243,7 +243,9 @@ class SetClipHooks:
if hooks is not None:
clip = clip.clone()
clip.use_clip_schedule = schedule_clip
clip.patcher.forced_hooks = hooks
clip.patcher.forced_hooks = hooks.clone()
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
return (clip,)