from typing import TYPE_CHECKING, Dict, List, Tuple

import torch

if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.sd import CLIP

import comfy.hooks
import comfy.sd
import comfy.utils  # RegisterHookLora calls comfy.utils.load_torch_file; import it explicitly
import folder_paths


###########################################
# Mask, Combine, and Hook Conditioning
#------------------------------------------
class PairConditioningSetProperties:
    """Set strength, mask, hooks and timestep range on a positive/negative conditioning pair.

    All work is delegated to ``comfy.hooks.set_mask_conds``; this node only maps
    its inputs onto that call and returns the two resulting conditionings.
    """
    NodeId = 'PairConditioningSetProperties'
    NodeName = 'Pair Cond Set Props'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive_NEW": ("CONDITIONING", ),
                "negative_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "opt_mask": ("MASK", ),
                "opt_hooks": ("HOOKS",),
                "opt_timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    def set_properties(self, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None,
                       opt_timesteps: Tuple=None):
        """Apply the given properties to both conds.

        Returns:
            (positive, negative) as produced by ``comfy.hooks.set_mask_conds``.
        """
        final_positive, final_negative = comfy.hooks.set_mask_conds(
            conds=[positive_NEW, negative_NEW],
            strength=strength,
            set_cond_area=set_cond_area,
            opt_mask=opt_mask,
            opt_hooks=opt_hooks,
            opt_timestep_range=opt_timesteps,
        )
        return (final_positive, final_negative)


class ConditioningSetProperties:
    """Single-conditioning variant of PairConditioningSetProperties."""
    NodeId = 'ConditioningSetProperties'
    NodeName = 'Cond Set Props'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "opt_mask": ("MASK", ),
                "opt_hooks": ("HOOKS",),
                "opt_timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    # Only one output; the previous RETURN_NAMES ("positive", "negative") had a
    # second name with no matching output, so it was removed to match the other
    # single-cond node (ConditioningSetDefaultAndCombine).
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    def set_properties(self, cond_NEW,
                       strength: float, set_cond_area: str,
                       opt_mask: torch.Tensor=None, opt_hooks: comfy.hooks.HookGroup=None,
                       opt_timesteps: Tuple=None):
        """Apply the given properties to ``cond_NEW`` via ``comfy.hooks.set_mask_conds``."""
        (final_cond,) = comfy.hooks.set_mask_conds(
            conds=[cond_NEW],
            strength=strength,
            set_cond_area=set_cond_area,
            opt_mask=opt_mask,
            opt_hooks=opt_hooks,
            opt_timestep_range=opt_timesteps,
        )
        return (final_cond,)


class PairConditioningCombine:
    """Combine pair B into pair A via ``comfy.hooks.set_mask_and_combine_conds``."""
    NodeId = 'PairConditioningCombine'
    NodeName = 'Pair Cond Combine'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive_A": ("CONDITIONING",),
                "negative_A": ("CONDITIONING",),
                "positive_B": ("CONDITIONING",),
                "negative_B": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "combine"

    def combine(self, positive_A, negative_A, positive_B, negative_B):
        final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(
            conds=[positive_A, negative_A],
            new_conds=[positive_B, negative_B],
        )
        return (final_positive, final_negative,)


class PairConditioningSetDefaultAndCombine:
    """Combine a DEFAULT positive/negative pair into a pair, optionally with hooks.

    Delegates to ``comfy.hooks.set_default_and_combine_conds``.
    """
    NodeId = 'PairConditioningSetDefaultCombine'
    NodeName = 'Pair Cond Set Default Combine'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "positive_DEFAULT": ("CONDITIONING",),
                "negative_DEFAULT": ("CONDITIONING",),
            },
            "optional": {
                "opt_hooks": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_default_and_combine"

    def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
                                opt_hooks: comfy.hooks.HookGroup=None):
        final_positive, final_negative = comfy.hooks.set_default_and_combine_conds(
            conds=[positive, negative],
            new_conds=[positive_DEFAULT, negative_DEFAULT],
            opt_hooks=opt_hooks,
        )
        return (final_positive, final_negative)


class ConditioningSetDefaultAndCombine:
    """Single-conditioning variant of PairConditioningSetDefaultAndCombine."""
    NodeId = 'ConditioningSetDefaultCombine'
    NodeName = 'Cond Set Default Combine'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "cond_DEFAULT": ("CONDITIONING",),
            },
            "optional": {
                "opt_hooks": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    # ComfyUI looks up the method named by FUNCTION on the node instance, so the
    # method below must carry exactly this name (it was previously misnamed
    # `append_and_combine`, which made the node fail at execution time).
    FUNCTION = "set_default_and_combine"

    def set_default_and_combine(self, cond, cond_DEFAULT, opt_hooks: comfy.hooks.HookGroup=None):
        (final_conditioning,) = comfy.hooks.set_default_and_combine_conds(
            conds=[cond],
            new_conds=[cond_DEFAULT],
            opt_hooks=opt_hooks,
        )
        return (final_conditioning,)

    # Backward-compatible alias for the method's previous name.
    append_and_combine = set_default_and_combine
#------------------------------------------
###########################################


###########################################
# Register Hooks
#------------------------------------------
class RegisterHookLora:
    """Load a LoRA file and register it on model/clip behind a new hook.

    Outputs the (model, clip) returned by ``comfy.hooks.load_hook_lora_for_models``
    plus a HookGroup containing the single newly created Hook.
    """
    NodeId = 'RegisterHookLora'
    NodeName = 'Register Hook LoRA'

    def __init__(self):
        # One-entry cache of the last LoRA file read by this node instance:
        # (lora_path, loaded_state) or None.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_lora"

    def register_lora(self, model: 'ModelPatcher', clip: 'CLIP', lora_name: str,
                      strength_model: float, strength_clip: float):
        """Register ``lora_name`` on model/clip behind a fresh hook.

        Returns:
            (model, clip, None) unchanged when both strengths are 0; otherwise
            (model_lora, clip_lora, hook_group).
        """
        # Nothing to apply: pass inputs through with no hooks.
        if strength_model == 0 and strength_clip == 0:
            return (model, clip, None)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                lora = self.loaded_lora[1]
            else:
                # Drop the reference to the previously cached LoRA before loading
                # a different one.
                temp = self.loaded_lora
                self.loaded_lora = None
                del temp

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        hook = comfy.hooks.Hook()
        hook_group = comfy.hooks.HookGroup()
        hook_group.add(hook)

        model_lora, clip_lora = comfy.hooks.load_hook_lora_for_models(
            model=model, clip=clip, lora=lora, hook=hook,
            strength_model=strength_model, strength_clip=strength_clip,
        )
        return (model_lora, clip_lora, hook_group)


class RegisterHookLoraModelOnly(RegisterHookLora):
    """Model-only variant of RegisterHookLora (no CLIP input/output)."""
    NodeId = 'RegisterHookLoraModelOnly'
    NodeName = 'Register Hook LoRA (MO)'

    @classmethod
    def INPUT_TYPES(s):
        # Only the inputs accepted by register_lora_model_only; declaring
        # "clip"/"strength_clip" here would make ComfyUI pass keyword arguments
        # the method does not accept.
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_lora_model_only"

    def register_lora_model_only(self, model: 'ModelPatcher', lora_name: str, strength_model: float):
        model_lora, _, hooks = self.register_lora(model=model, clip=None, lora_name=lora_name,
                                                  strength_model=strength_model, strength_clip=0)
        return (model_lora, hooks)


class RegisterHookModelAsLora:
    """Load a checkpoint and register it on model/clip, as a LoRA, behind a new hook.

    Outputs the (model, clip) returned by
    ``comfy.hooks.load_hook_model_as_lora_for_models`` plus a HookGroup
    containing the single newly created Hook.
    """
    NodeId = 'RegisterHookModelAsLora'
    NodeName = 'Register Hook Model as LoRA'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_model_as_lora"

    def register_model_as_lora(self, model: 'ModelPatcher', clip: 'CLIP', ckpt_name: str,
                               strength_model: float, strength_clip: float):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path, output_vae=True, output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
        # Only the loaded model and clip (first two entries) are used here.
        model_loaded = out[0]
        clip_loaded = out[1]

        hook = comfy.hooks.Hook()
        hook_group = comfy.hooks.HookGroup()
        hook_group.add(hook)

        model_lora, clip_lora = comfy.hooks.load_hook_model_as_lora_for_models(
            model=model, clip=clip,
            model_loaded=model_loaded, clip_loaded=clip_loaded,
            hook=hook,
            strength_model=strength_model, strength_clip=strength_clip,
        )
        return (model_lora, clip_lora, hook_group)


class RegisterHookModelAsLoraModelOnly(RegisterHookModelAsLora):
    """Model-only variant of RegisterHookModelAsLora (no CLIP input/output)."""
    NodeId = 'RegisterHookModelAsLoraModelOnly'
    NodeName = 'Register Hook Model as LoRA (MO)'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_model_as_lora_model_only"

    def register_model_as_lora_model_only(self, model: 'ModelPatcher', ckpt_name: str, strength_model: float):
        model_lora, _, hooks = self.register_model_as_lora(model=model, clip=None, ckpt_name=ckpt_name,
                                                           strength_model=strength_model, strength_clip=0)
        return (model_lora, hooks)
#------------------------------------------
###########################################


###########################################
# Schedule Hooks
#------------------------------------------
#------------------------------------------
###########################################


class SetModelHooksOnCond:
    """Attach hooks to a conditioning via ``comfy.hooks.set_hooks_for_conditioning``.

    NOTE: not included in ``node_list`` below, so it is not registered with ComfyUI.
    NodeId/NodeName are defined so it can be registered without breaking the
    registration loop.
    """
    NodeId = 'SetModelHooksOnCond'
    NodeName = 'Set Model Hooks on Cond'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "hooks": ("HOOKS",),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/manual"
    FUNCTION = "attach_hook"

    def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
        return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)


###########################################
# Combine Hooks
#------------------------------------------
class CombineHooks:
    """Combine up to 2 optional HOOKS inputs via ``HookGroup.combine_all_hooks``.

    Unconnected inputs are passed as None; presumably combine_all_hooks skips
    them — confirm in comfy.hooks.
    """
    NodeId = 'CombineHooks2'
    NodeName = 'Combine Hooks [2]'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)


class CombineHooksFour:
    """Combine up to 4 optional HOOKS inputs via ``HookGroup.combine_all_hooks``."""
    NodeId = 'CombineHooks4'
    NodeName = 'Combine Hooks [4]'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
                "hooks_C": ("HOOKS",),
                "hooks_D": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)


class CombineHooksEight:
    """Combine up to 8 optional HOOKS inputs via ``HookGroup.combine_all_hooks``."""
    NodeId = 'CombineHooks8'
    NodeName = 'Combine Hooks [8]'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
            },
            "optional": {
                "hooks_A": ("HOOKS",),
                "hooks_B": ("HOOKS",),
                "hooks_C": ("HOOKS",),
                "hooks_D": ("HOOKS",),
                "hooks_E": ("HOOKS",),
                "hooks_F": ("HOOKS",),
                "hooks_G": ("HOOKS",),
                "hooks_H": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None,
                      hooks_E: comfy.hooks.HookGroup=None,
                      hooks_F: comfy.hooks.HookGroup=None,
                      hooks_G: comfy.hooks.HookGroup=None,
                      hooks_H: comfy.hooks.HookGroup=None):
        candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
        return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
#------------------------------------------
###########################################


# Nodes exposed to ComfyUI; each must define NodeId and NodeName.
node_list = [
    # Register
    RegisterHookLora,
    RegisterHookLoraModelOnly,
    RegisterHookModelAsLora,
    RegisterHookModelAsLoraModelOnly,
    # Combine
    CombineHooks,
    CombineHooksFour,
    CombineHooksEight,
    # Attach
    ConditioningSetProperties,
    PairConditioningSetProperties,
    ConditioningSetDefaultAndCombine,
    PairConditioningSetDefaultAndCombine,
    PairConditioningCombine,
]

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

for node in node_list:
    NODE_CLASS_MAPPINGS[node.NodeId] = node
    NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName