Turn off hook patch caching when only one hook group is present during sampling; replace some self.current_hooks = None assignments with calls to self.patch_hooks(None) to avoid a potential edge case

Jedrzej Kosinski 2024-11-20 18:57:18 -06:00
parent 35016983fb
commit d38c535771
2 changed files with 19 additions and 2 deletions
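
As a rough, hedged illustration of what the commit does (every name below is a hypothetical, simplified stand-in for the real ComfyUI classes shown in the diff, not the actual API): when the conds reference at most one hook group, the patched hook weights can never change between sampling steps, so the model patcher is temporarily switched to a minimal-VRAM mode that skips hook weight caching, and the original mode is restored afterwards.

from enum import Enum, auto

class HookMode(Enum):        # stand-in for comfy.hooks.EnumHookMode
    MinVram = auto()         # patch hooks per step, cache nothing
    MaxSpeed = auto()        # cache patched weights per hook group

def count_hook_groups(conds: dict) -> int:
    # stand-in for get_total_hook_groups_in_conds in the diff: count distinct hook groups
    groups = set()
    for cond_list in conds.values():
        for cond in cond_list:
            groups.add(cond.get('hooks', None))  # "no hooks" counts as one group
    return len(groups)

def sample_with_hook_gating(patcher, conds):
    orig_mode = patcher.hook_mode
    try:
        if count_hook_groups(conds) <= 1:
            # a single (or no) hook group means the patched weights are fixed
            # after the first step, so caching them per group only wastes memory
            patcher.hook_mode = HookMode.MinVram
        ...  # prepare the patcher and run the sampler here
    finally:
        patcher.hook_mode = orig_mode  # always restore the configured mode
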


@@ -886,6 +886,7 @@ class ModelPatcher:
     def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
         curr_t = t[0]
+        reset_current_hooks = False
         for hook in hook_group.hooks:
             changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
             # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
@@ -895,11 +896,13 @@ class ModelPatcher:
                 if self.current_hooks is not None:
                     for current_hook in self.current_hooks.hooks:
                         if current_hook == hook:
-                            self.current_hooks = None
+                            reset_current_hooks = True
                             break
                 for cached_group in list(self.cached_hook_patches.keys()):
                     if cached_group.contains(hook):
                         self.cached_hook_patches.pop(cached_group)
+        if reset_current_hooks:
+            self.patch_hooks(None)

     def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
         self.restore_hook_patches()
@@ -1020,7 +1023,7 @@ class ModelPatcher:
     def clear_cached_hook_weights(self):
         self.cached_hook_patches.clear()
-        self.current_hooks = None
+        self.patch_hooks(None)

     def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
         if key not in combined_patches:

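The commit message only calls it "a potential edge case", so the following is a guess at the failure mode, shown with a toy class rather than ComfyUI's real ModelPatcher internals: setting current_hooks = None only forgets the bookkeeping, while routing through patch_hooks(None) also unpatches whatever the previous hooks applied before resetting the state.

class ToyPatcher:
    # hypothetical, greatly simplified stand-in for the real ModelPatcher
    def __init__(self):
        self.current_hooks = None
        self.applied = []              # weight patches currently applied to the model

    def unpatch_hooks(self):
        self.applied.clear()           # undo whatever the current hooks applied

    def patch_hooks(self, hooks):
        self.unpatch_hooks()           # clean up the previous state first
        self.current_hooks = hooks
        if hooks is not None:
            self.applied.append(hooks)

p = ToyPatcher()
p.patch_hooks("groupA")
# p.current_hooks = None   # would leave "groupA" weights applied while claiming no hooks
p.patch_hooks(None)        # unpatches and resets, avoiding the stale-weights mismatch
assert p.applied == [] and p.current_hooks is None
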

@@ -805,6 +805,15 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
         for cond in conds_to_modify:
             cond['hooks'] = hooks

+def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
+    hooks_set = set()
+    for k in conds:
+        for kk in conds[k]:
+            hooks_set.add(kk.get('hooks', None))
+    return len(hooks_set)
+
 class CFGGuider:
     def __init__(self, model_patcher):
         self.model_patcher: 'ModelPatcher' = model_patcher
@@ -878,6 +887,10 @@ class CFGGuider:
         try:
             orig_model_options = self.model_options
             self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
+            # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
+            orig_hook_mode = self.model_patcher.hook_mode
+            if get_total_hook_groups_in_conds(self.conds) <= 1:
+                self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
             comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
             executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                 self.outer_sample,
@@ -887,6 +900,7 @@ class CFGGuider:
             output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
             self.model_options = orig_model_options
+            self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()

         del self.conds
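
A quick usage sketch for the get_total_hook_groups_in_conds helper added above (the conds layout is illustrative; real conds hold HookGroup objects rather than strings): it counts the distinct 'hooks' values across every cond list, with conds that carry no hooks collapsing into a single None entry.

conds = {
    'positive': [{'hooks': 'groupA'}, {'hooks': 'groupA'}],  # shared group counts once
    'negative': [{}],                                        # no hooks -> None
}
# get_total_hook_groups_in_conds(conds) -> 2 here, i.e. {'groupA', None}, so caching stays on;
# a result of 0 or 1 is what flips the patcher into the MinVram (no caching) path in the hunk above.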