Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case
This commit is contained in:
parent
35016983fb
commit
d38c535771
|
@ -886,6 +886,7 @@ class ModelPatcher:
|
|||
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
|
@ -895,11 +896,13 @@ class ModelPatcher:
|
|||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
self.current_hooks = None
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
|
||||
self.restore_hook_patches()
|
||||
|
@ -1020,7 +1023,7 @@ class ModelPatcher:
|
|||
|
||||
def clear_cached_hook_weights(self):
|
||||
self.cached_hook_patches.clear()
|
||||
self.current_hooks = None
|
||||
self.patch_hooks(None)
|
||||
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
|
|
|
@ -805,6 +805,15 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
|
|||
for cond in conds_to_modify:
|
||||
cond['hooks'] = hooks
|
||||
|
||||
|
||||
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
|
||||
hooks_set = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
hooks_set.add(kk.get('hooks', None))
|
||||
return len(hooks_set)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher: 'ModelPatcher' = model_patcher
|
||||
|
@ -878,6 +887,10 @@ class CFGGuider:
|
|||
try:
|
||||
orig_model_options = self.model_options
|
||||
self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
|
||||
orig_hook_mode = self.model_patcher.hook_mode
|
||||
if get_total_hook_groups_in_conds(self.conds) <= 1:
|
||||
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.outer_sample,
|
||||
|
@ -887,6 +900,7 @@ class CFGGuider:
|
|||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_options = orig_model_options
|
||||
self.model_patcher.hook_mode = orig_hook_mode
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
|
|
Loading…
Reference in New Issue