For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes

This commit is contained in:
Jedrzej Kosinski 2024-11-18 12:25:17 -06:00
parent 365170af95
commit 9fe3db4c3a
1 changed files with 9 additions and 3 deletions

View File

@ -128,7 +128,7 @@ class CLIP:
pooled_dict["hooks"] = self.apply_hooks_to_conds
return pooled_dict
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]=None, show_pbar=True):
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
@ -147,6 +147,13 @@ class CLIP:
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
@ -160,8 +167,7 @@ class CLIP:
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
if add_dict is not None:
pooled_dict.update(add_dict)
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])