Fixed inconsistent results when schedule_clip is set to False, small renames/typo fixes, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, and initial work on calc_cond_batch merging all sub-dicts in the returned transformer_options
parent 9dde713347
commit 638c4086a3
@@ -118,6 +118,14 @@ class ControlBase:
         if self.previous_controlnet is not None:
             out += self.previous_controlnet.get_models()
         return out
 
+    def get_extra_hooks(self):
+        out = []
+        if self.extra_hooks is not None:
+            out.append(self.extra_hooks)
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_extra_hooks()
+        return out
+
     def copy_to(self, c):
         c.cond_hint_original = self.cond_hint_original
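For orientation, the new get_extra_hooks mirrors the existing get_models traversal: each ControlNet contributes its own extra_hooks and then defers to previous_controlnet, so a chained setup yields hooks from the whole chain. A minimal sketch of that accumulation pattern using toy stand-in classes (not the real ControlBase):

class ToyControl:
    def __init__(self, extra_hooks=None, previous=None):
        self.extra_hooks = extra_hooks
        self.previous_controlnet = previous

    def get_extra_hooks(self):
        # collect this control's hooks first, then everything further down the chain
        out = []
        if self.extra_hooks is not None:
            out.append(self.extra_hooks)
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_extra_hooks()
        return out

chain = ToyControl(extra_hooks="hooks_B", previous=ToyControl(extra_hooks="hooks_A"))
assert chain.get_extra_hooks() == ["hooks_B", "hooks_A"]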
@@ -54,7 +54,7 @@ class Hook:
 
     def initialize_timesteps(self, model: 'BaseModel'):
         self.reset()
-        self.hook_keyframe.initalize_timesteps(model)
+        self.hook_keyframe.initialize_timesteps(model)
 
     def reset(self):
         self.hook_keyframe.reset()
@@ -193,7 +193,7 @@ class AddModelsHook(Hook):
     def add_hook_models(self, model: 'ModelPatcher'):
         pass
 
-class AddCallbackHook(Hook):
+class CallbackHook(Hook):
     def __init__(self, key: str=None, callback: Callable=None):
         super().__init__(hook_type=EnumHookType.AddCallback)
         self.key = key
@@ -202,7 +202,7 @@ class AddCallbackHook(Hook):
     def clone(self, subtype: Callable=None):
         if subtype is None:
             subtype = type(self)
-        c: AddCallbackHook = super().clone(subtype)
+        c: CallbackHook = super().clone(subtype)
         c.key = self.key
         c.callback = self.callback
         return c
@@ -227,7 +227,7 @@ class SetInjectionsHook(Hook):
     def add_hook_injections(self, model: 'ModelPatcher'):
         pass
 
-class AddWrapperHook(Hook):
+class WrapperHook(Hook):
     def __init__(self, key: str=None, wrapper: Callable=None):
         super().__init__(hook_type=EnumHookType.AddWrapper)
         self.key = key
@@ -236,7 +236,7 @@ class AddWrapperHook(Hook):
     def clone(self, subtype: Callable=None):
         if subtype is None:
             subtype = type(self)
-        c: AddWrapperHook = super().clone(subtype)
+        c: WrapperHook = super().clone(subtype)
        c.key = self.key
        c.wrapper = self.wrapper
        return c
@@ -268,7 +268,10 @@ class HookGroup:
         return c
 
     def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
-        hook_kf = hook_kf.clone()
+        if hook_kf is None:
+            hook_kf = HookKeyframeGroup()
+        else:
+            hook_kf = hook_kf.clone()
         for hook in self.hooks:
             hook.hook_keyframe = hook_kf
 
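The None branch matters because callers can now pass None to wipe keyframes (the SetClipHooks node in the last hunk below does exactly that when schedule_clip is off); previously hook_kf.clone() would have raised AttributeError on None. A hedged usage sketch, assuming a ComfyUI environment where comfy.hooks is importable:

import comfy.hooks

group = comfy.hooks.HookGroup()
# ... hooks added to the group elsewhere ...
group.set_keyframes_on_hooks(None)  # every hook now shares one fresh, empty HookKeyframeGroup
for hook in group.hooks:
    assert isinstance(hook.hook_keyframe, comfy.hooks.HookKeyframeGroup)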
@@ -330,7 +333,7 @@ class HookGroup:
 
     def reset(self):
         for hook in self.hooks:
-            hook.hook_keyframe.reset()
+            hook.reset()
 
     @staticmethod
     def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
@@ -414,11 +417,11 @@ class HookKeyframeGroup:
     def clone(self):
         c = HookKeyframeGroup()
         for keyframe in self.keyframes:
-            c.keyframes.append(keyframe)
+            c.keyframes.append(keyframe.clone())
         c._set_first_as_current()
         return c
 
-    def initalize_timesteps(self, model: 'BaseModel'):
+    def initialize_timesteps(self, model: 'BaseModel'):
         for keyframe in self.keyframes:
             keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
 
@@ -990,12 +990,14 @@ class ModelPatcher:
             combined_patches[key] = current_patches
         return combined_patches
 
-    def apply_hooks(self, hooks: comfy.hooks.HookGroup):
+    def apply_hooks(self, hooks: comfy.hooks.HookGroup, model_options: dict=None):
+        # TODO: return transformer_options dict with any additions from hooks
         if self.current_hooks == hooks:
-            return
+            return {}
         self.patch_hooks(hooks=hooks)
         for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
             callback(self, hooks)
+        return {}
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
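The empty-dict return is a placeholder for the work described in the commit message: apply_hooks is intended to eventually report transformer_options contributed by hooks, and returning {} on both the early-exit and normal paths lets callers merge the result unconditionally. A sketch of that calling contract (the helper name and extra_options are hypothetical):

def build_transformer_options(patcher, hooks, extra_options: dict) -> dict:
    # apply_hooks now always returns a dict (currently {}), so no None checks are needed
    hook_options = patcher.apply_hooks(hooks=hooks)
    return {**hook_options, **extra_options}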
@@ -139,8 +139,11 @@ def copy_nested_dicts(input_dict: dict):
             new_dict[key] = value.copy()
     return new_dict
 
-def merge_nested_dicts(dict1: dict, dict2: dict):
-    merged_dict = copy_nested_dicts(dict1)
+def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
+    if copy_dict1:
+        merged_dict = copy_nested_dicts(dict1)
+    else:
+        merged_dict = dict1
     for key, value in dict2.items():
         if isinstance(value, dict):
             curr_value = merged_dict.setdefault(key, {})
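The copy_dict1 flag decides whether dict1 is copied before merging or mutated in place; the in-place path saves a copy when the caller already owns a throwaway dict, as _calc_cond_batch does below. A standalone sketch of the same one-level merge semantics (a simplified reimplementation, not the library function):

def merge_nested_dicts_sketch(dict1: dict, dict2: dict, copy_dict1=True) -> dict:
    # copy dict1 (and its dict values) when requested, otherwise merge into dict1 directly
    merged = {k: (v.copy() if isinstance(v, dict) else v) for k, v in dict1.items()} if copy_dict1 else dict1
    for key, value in dict2.items():
        if isinstance(value, dict):
            merged.setdefault(key, {}).update(value)
        else:
            merged[key] = value
    return merged

a = {"patches": {"x": 1}, "flag": True}
b = {"patches": {"y": 2}}
out = merge_nested_dicts_sketch(a, b, copy_dict1=False)
assert out is a and out["patches"] == {"x": 1, "y": 2}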
@@ -1,9 +1,11 @@
 from __future__ import annotations
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
+    from comfy.model_base import BaseModel
+    from comfy.controlnet import ControlBase
 import torch
 import collections
 from comfy import model_management
@@ -248,8 +250,6 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor
                 to_batch = batch_amount
                 break
 
-        model.current_patcher.apply_hooks(hooks=hooks)
-
         input_x = []
         mult = []
         c = []
@@ -275,11 +275,14 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor
         c = cond_cat(c)
         timestep_ = torch.cat([timestep] * batch_chunks)
 
-        transformer_options = {}
+        transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
         if 'transformer_options' in model_options:
-            transformer_options = model_options['transformer_options'].copy()
+            transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+                                                                             model_options['transformer_options'],
+                                                                             copy_dict1=False)
 
         if patches is not None:
+            # TODO: replace with merge_nested_dicts function
             if "patches" in transformer_options:
                 cur_patches = transformer_options["patches"].copy()
                 for p in patches:
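The net effect: whatever dict apply_hooks returns becomes the base, and model_options['transformer_options'] is layered on top with nested sub-dicts merged key-by-key rather than replaced wholesale (copy_dict1=False is safe here because the base dict is freshly created per batch). A small before/after illustration with placeholder values:

hook_opts = {"sub_opts": {"from_hook": 1}}                            # stand-in for apply_hooks output
model_opts = {"transformer_options": {"sub_opts": {"from_user": 2}}}

old_result = model_opts["transformer_options"].copy()                # old path: hook entries discarded

new_result = {k: (v.copy() if isinstance(v, dict) else v) for k, v in hook_opts.items()}
for key, value in model_opts["transformer_options"].items():         # new path: sub-dicts merged
    if isinstance(value, dict):
        new_result.setdefault(key, {}).update(value)
    else:
        new_result[key] = value
assert new_result == {"sub_opts": {"from_hook": 1, "from_user": 2}}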
@@ -774,6 +777,34 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None
 
     return conds
 
+
+def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
+    # determine which ControlNets have extra_hooks that should be combined with normal hooks
+    hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
+    for k in conds:
+        for kk in conds[k]:
+            if 'control' in kk:
+                control: 'ControlBase' = kk['control']
+                extra_hooks = control.get_extra_hooks()
+                if len(extra_hooks) > 0:
+                    hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
+                    to_replace = hook_replacement.setdefault((control, hooks), [])
+                    to_replace.append(kk)
+    # if nothing to replace, do nothing
+    if len(hook_replacement) == 0:
+        return
+
+    # for optimal sampling performance, common ControlNets + hook combos should have identical hooks
+    # on the cond dicts
+    for key, conds_to_modify in hook_replacement.items():
+        control = key[0]
+        hooks = key[1]
+        hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
+        # if combined hooks are not None, set as new hooks for all relevant conds
+        if hooks is not None:
+            for cond in conds_to_modify:
+                cond['hooks'] = hooks
+
 class CFGGuider:
     def __init__(self, model_patcher):
         self.model_patcher: 'ModelPatcher' = model_patcher
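Grouping by the (control, hooks) pair means every cond that shares both the same ControlNet and the same pre-existing hooks ends up pointing at one combined HookGroup object, which keeps those conds batchable during sampling. A toy, runnable restatement of just the grouping step (strings stand in for the real ControlBase and HookGroup objects, and the extra_hooks check is omitted):

conds = {
    "positive": [{"control": "cnet_1", "hooks": "hooks_A"}, {"control": "cnet_1", "hooks": "hooks_A"}],
    "negative": [{"control": "cnet_2"}],
}
hook_replacement = {}
for k in conds:
    for kk in conds[k]:
        if "control" in kk:
            key = (kk["control"], kk.get("hooks", None))
            hook_replacement.setdefault(key, []).append(kk)

assert len(hook_replacement[("cnet_1", "hooks_A")]) == 2
assert len(hook_replacement[("cnet_2", None)]) == 1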
@@ -842,6 +873,7 @@ class CFGGuider:
         self.conds = {}
         for k in self.original_conds:
             self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
+        preprocess_conds_hooks(self.conds)
 
         try:
             orig_model_options = self.model_options
comfy/sd.py
@@ -4,6 +4,7 @@ from enum import Enum
 import logging
 
 from comfy import model_management
+from comfy.utils import ProgressBar
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
@@ -119,7 +120,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]=None):
+    def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]=None, show_pbar=True):
         all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
         all_hooks = self.patcher.forced_hooks
         scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
@@ -132,6 +133,10 @@ class CLIP:
 
         self.load_model()
+        all_hooks.reset()
+        self.patcher.patch_hooks(None)
+        if show_pbar:
+            pbar = ProgressBar(len(scheduled_keyframes))
 
         for scheduled_opts in scheduled_keyframes:
             t_range = scheduled_opts[0]
             hooks_keyframes = scheduled_opts[1]
@@ -150,6 +155,9 @@ class CLIP:
                 if add_dict is not None:
                     pooled_dict.update(add_dict)
                 all_cond_pooled.append([cond, pooled_dict])
+            if show_pbar:
+                pbar.update(1)
             model_management.throw_exception_if_processing_interrupted()
+        all_hooks.reset()
         return all_cond_pooled
 
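With the new flag, code that encodes many scheduled prompts in a loop can suppress the per-keyframe progress bar. A hedged usage sketch (the clip object and prompt are assumed; forced_hooks must already be set for scheduling to matter):

tokens = clip.tokenize("a photo of a cat")
conds = clip.encode_from_tokens_scheduled(tokens)                         # default: pbar updates per keyframe
conds_quiet = clip.encode_from_tokens_scheduled(tokens, show_pbar=False)  # no progress bar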
@@ -243,7 +243,9 @@ class SetClipHooks:
         if hooks is not None:
             clip = clip.clone()
             clip.use_clip_schedule = schedule_clip
-            clip.patcher.forced_hooks = hooks
+            clip.patcher.forced_hooks = hooks.clone()
+            if not clip.use_clip_schedule:
+                clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
             clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
         return (clip,)
 
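Together with the set_keyframes_on_hooks change above, this is the schedule_clip consistency fix: the node clones the incoming hooks before mutating them and, when scheduling is disabled, strips keyframes so every encode sees the same un-scheduled hook strengths. A condensed restatement of the new control flow (not the literal node code; names mirror the diff):

import comfy.hooks

def set_clip_hooks_sketch(clip, hooks, schedule_clip: bool):
    if hooks is not None:
        clip = clip.clone()
        clip.use_clip_schedule = schedule_clip
        clip.patcher.forced_hooks = hooks.clone()  # never mutate the caller's HookGroup
        if not schedule_clip:
            # drop keyframes so hook strength no longer varies across the schedule
            clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
        clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
    return (clip,)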