# ComfyUI/comfy_extras/nodes_hooks.py
# (listing metadata: 433 lines, 16 KiB, Python)

from typing import TYPE_CHECKING, Dict, List, Tuple

import torch

if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.sd import CLIP

import comfy.hooks
import comfy.sd
import comfy.utils
import folder_paths
###########################################
# Mask, Combine, and Hook Conditioning
#------------------------------------------
class PairConditioningSetProperties:
    """Apply strength, area mode, mask, hooks and timestep range to a positive/negative pair.

    Both conditionings are forwarded together to ``comfy.hooks.set_mask_conds``,
    which returns the processed pair in the same (positive, negative) order.
    """
    NodeId = 'PairConditioningSetProperties'
    NodeName = 'Pair Cond Set Props'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive_NEW": ("CONDITIONING", ),
                "negative_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "opt_mask": ("MASK", ),
                "opt_hooks": ("HOOKS",),
                "opt_timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_properties"

    def set_properties(self, positive_NEW, negative_NEW,
                       strength: float, set_cond_area: str,
                       opt_mask: torch.Tensor=None, opt_hooks: 'comfy.hooks.HookGroup'=None, opt_timesteps: Tuple=None):
        """Set properties on a conditioning pair.

        Args:
            positive_NEW: positive conditioning to update.
            negative_NEW: negative conditioning to update.
            strength: strength value forwarded to ``set_mask_conds``.
            set_cond_area: one of ``"default"`` or ``"mask bounds"``.
            opt_mask: optional MASK input.
            opt_hooks: optional HOOKS input. The RegisterHook* nodes produce a
                ``HookGroup`` for this socket type, hence the annotation; it is a
                string so it is not evaluated at definition time.
            opt_timesteps: optional TIMESTEPS_RANGE, passed on as ``opt_timestep_range``.

        Returns:
            ``(positive, negative)`` conditioning tuple.
        """
        final_positive, final_negative = comfy.hooks.set_mask_conds(
            conds=[positive_NEW, negative_NEW],
            strength=strength, set_cond_area=set_cond_area,
            opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
        return (final_positive, final_negative)
class ConditioningSetProperties:
    """Apply strength, area mode, mask, hooks and timestep range to one conditioning.

    The conditioning is forwarded to ``comfy.hooks.set_mask_conds`` as a
    single-element list, and the single processed conditioning is returned.
    """
    NodeId = 'ConditioningSetProperties'
    NodeName = 'Cond Set Props'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond_NEW": ("CONDITIONING", ),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "set_cond_area": (["default", "mask bounds"],),
            },
            "optional": {
                "opt_mask": ("MASK", ),
                "opt_hooks": ("HOOKS",),
                "opt_timesteps": ("TIMESTEPS_RANGE",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    # RETURN_NAMES is deliberately omitted: this node has a single output, and a
    # names tuple must have the same length as RETURN_TYPES.
    CATEGORY = "advanced/hooks/cond single"
    FUNCTION = "set_properties"

    def set_properties(self, cond_NEW,
                       strength: float, set_cond_area: str,
                       opt_mask: torch.Tensor=None, opt_hooks: 'comfy.hooks.HookGroup'=None, opt_timesteps: Tuple=None):
        """Set properties on a single conditioning.

        Args:
            cond_NEW: conditioning to update.
            strength: strength value forwarded to ``set_mask_conds``.
            set_cond_area: one of ``"default"`` or ``"mask bounds"``.
            opt_mask: optional MASK input.
            opt_hooks: optional HOOKS input (a ``HookGroup``, as produced by the
                RegisterHook* nodes); string annotation avoids definition-time evaluation.
            opt_timesteps: optional TIMESTEPS_RANGE, passed on as ``opt_timestep_range``.

        Returns:
            One-element tuple holding the processed conditioning.
        """
        (final_cond,) = comfy.hooks.set_mask_conds(
            conds=[cond_NEW],
            strength=strength, set_cond_area=set_cond_area,
            opt_mask=opt_mask, opt_hooks=opt_hooks, opt_timestep_range=opt_timesteps)
        return (final_cond,)
class PairConditioningCombine:
    """Merge two positive/negative conditioning pairs into a single pair.

    Pair A is handed to ``comfy.hooks.set_mask_and_combine_conds`` as ``conds``
    and pair B as ``new_conds``; the helper's two results are returned as-is.
    """
    NodeId = 'PairConditioningCombine'
    NodeName = 'Pair Cond Combine'
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "combine"

    @classmethod
    def INPUT_TYPES(s):
        required = {}
        # Insertion order matters for socket layout: positive_A, negative_A, positive_B, negative_B.
        for suffix in ("A", "B"):
            required[f"positive_{suffix}"] = ("CONDITIONING",)
            required[f"negative_{suffix}"] = ("CONDITIONING",)
        return {"required": required}

    def combine(self, positive_A, negative_A, positive_B, negative_B):
        pair_a = [positive_A, negative_A]
        pair_b = [positive_B, negative_B]
        merged_positive, merged_negative = comfy.hooks.set_mask_and_combine_conds(conds=pair_a, new_conds=pair_b)
        return (merged_positive, merged_negative)
class PairConditioningSetDefaultAndCombine:
    """Combine a positive/negative pair with a "default" pair.

    ``(positive, negative)`` are passed to ``comfy.hooks.set_default_and_combine_conds``
    as ``conds``, the ``*_DEFAULT`` pair as ``new_conds``, plus the optional hooks.
    """
    NodeId = 'PairConditioningSetDefaultCombine'
    NodeName = 'Pair Cond Set Default Combine'
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    CATEGORY = "advanced/hooks/cond pair"
    FUNCTION = "set_default_and_combine"

    @classmethod
    def INPUT_TYPES(s):
        input_names = ("positive", "negative", "positive_DEFAULT", "negative_DEFAULT")
        required = {name: ("CONDITIONING",) for name in input_names}
        optional = {"opt_hooks": ("HOOKS",)}
        return {"required": required, "optional": optional}

    def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
                                opt_hooks: comfy.hooks.HookGroup=None):
        current_pair = [positive, negative]
        default_pair = [positive_DEFAULT, negative_DEFAULT]
        combined_positive, combined_negative = comfy.hooks.set_default_and_combine_conds(
            conds=current_pair, new_conds=default_pair, opt_hooks=opt_hooks)
        return (combined_positive, combined_negative)
class ConditioningSetDefaultAndCombine:
    """Combine a single conditioning with a "default" conditioning.

    ``cond`` is passed to ``comfy.hooks.set_default_and_combine_conds`` as the only
    element of ``conds`` and ``cond_DEFAULT`` as the only element of ``new_conds``,
    together with the optional hooks.
    """
    NodeId = 'ConditioningSetDefaultCombine'
    NodeName = 'Cond Set Default Combine'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "cond_DEFAULT": ("CONDITIONING",),
            },
            "optional": {
                "opt_hooks": ("HOOKS",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/cond single"
    # Must name a method on this class; the node executor looks it up by this string.
    FUNCTION = "set_default_and_combine"

    def set_default_and_combine(self, cond, cond_DEFAULT,
                                opt_hooks: 'comfy.hooks.HookGroup'=None):
        """Combine ``cond`` with ``cond_DEFAULT``.

        Args:
            cond: conditioning to combine.
            cond_DEFAULT: default conditioning to combine with it.
            opt_hooks: optional HOOKS input (a ``HookGroup``); string annotation
                avoids definition-time evaluation.

        Returns:
            One-element tuple holding the combined conditioning.
        """
        (final_conditioning,) = comfy.hooks.set_default_and_combine_conds(
            conds=[cond], new_conds=[cond_DEFAULT], opt_hooks=opt_hooks)
        return (final_conditioning,)

    # Backward-compatible alias for the method's previous name (which did not
    # match FUNCTION and so could not be invoked by the executor).
    append_and_combine = set_default_and_combine
#------------------------------------------
###########################################
###########################################
# Register Hooks
#------------------------------------------
class RegisterHookLora:
    """Load a LoRA file and register it as a hook on a model/CLIP pair.

    Produces the patched model and CLIP plus a HookGroup holding the new hook.
    The most recently loaded LoRA (path and loaded data) is cached on the
    instance, so re-running with the same file skips the disk read.
    """
    NodeId = 'RegisterHookLora'
    NodeName = 'Register Hook LoRA'
    RETURN_TYPES = ("MODEL", "CLIP", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_lora"

    def __init__(self):
        # Either None or a (lora_path, loaded_lora) tuple for the last file read.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        lora_choices = folder_paths.get_filename_list("loras")
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "lora_name": (lora_choices, ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    def register_lora(self, model: 'ModelPatcher', clip: 'CLIP', lora_name: str,
                      strength_model: float, strength_clip: float):
        """Register ``lora_name`` as a hook on ``model`` and ``clip``.

        When both strengths are 0 nothing is loaded and ``(model, clip, None)``
        is returned unchanged.
        """
        if strength_model == 0 and strength_clip == 0:
            return (model, clip, None)

        lora_path = folder_paths.get_full_path("loras", lora_name)

        lora = None
        cached = self.loaded_lora
        if cached is not None:
            if cached[0] == lora_path:
                lora = cached[1]
            else:
                # Drop the stale cache entry before loading a different file so
                # the previous data is no longer referenced by this instance.
                self.loaded_lora = None
            del cached

        if lora is None:
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        hook = comfy.hooks.Hook()
        group = comfy.hooks.HookGroup()
        group.add(hook)
        patched_model, patched_clip = comfy.hooks.load_hook_lora_for_models(
            model=model, clip=clip, lora=lora, hook=hook,
            strength_model=strength_model, strength_clip=strength_clip)
        return (patched_model, patched_clip, group)
class RegisterHookLoraModelOnly(RegisterHookLora):
    """Model-only ("MO") variant of RegisterHookLora.

    Exposes no CLIP input or output; the inherited ``register_lora`` is called
    with ``clip=None`` and ``strength_clip=0``.
    """
    NodeId = 'RegisterHookLoraModelOnly'
    NodeName = 'Register Hook LoRA (MO)'

    @classmethod
    def INPUT_TYPES(s):
        # Declared inputs are matched by name to the parameters of FUNCTION, so
        # only list what register_lora_model_only accepts ("clip" and
        # "strength_clip" have no matching parameter and would break the call).
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_lora_model_only"

    def register_lora_model_only(self, model: 'ModelPatcher', lora_name: str, strength_model: float):
        """Register ``lora_name`` as a hook on ``model`` only.

        Returns:
            ``(model, hooks)``; with ``strength_model == 0`` the inherited
            short-circuit returns the input model and ``None`` hooks.
        """
        model_lora, _, hooks = self.register_lora(model=model, clip=None, lora_name=lora_name,
                                                  strength_model=strength_model, strength_clip=0)
        return (model_lora, hooks)
class RegisterHookModelAsLora:
    """Register a whole checkpoint's weights as hook patches on a model/CLIP pair.

    The checkpoint named by ``ckpt_name`` is loaded, and its model (and, when a
    CLIP is supplied, its CLIP) are applied through
    ``comfy.hooks.load_hook_model_as_lora_for_models`` under a fresh hook.
    """
    NodeId = 'RegisterHookModelAsLora'
    NodeName = 'Register Hook Model as LoRA'

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_model_as_lora"

    def register_model_as_lora(self, model: 'ModelPatcher', clip: 'CLIP', ckpt_name: str,
                               strength_model: float, strength_clip: float):
        """Register checkpoint ``ckpt_name`` as a hook on ``model`` and ``clip``.

        Returns:
            ``(model, clip, hooks)``. When both strengths are 0 nothing is loaded
            and the inputs are returned with ``None`` hooks, matching RegisterHookLora.
        """
        # Nothing to patch: avoid reading an entire checkpoint from disk.
        if strength_model == 0 and strength_clip == 0:
            return (model, clip, None)

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        # The VAE (out[2]) is never used here, and CLIP weights are only needed
        # when there is a CLIP to patch (the model-only node passes clip=None).
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=False, output_clip=clip is not None,
                                                    embedding_directory=folder_paths.get_folder_paths("embeddings"))
        model_loaded = out[0]
        clip_loaded = out[1]

        hook = comfy.hooks.Hook()
        hook_group = comfy.hooks.HookGroup()
        hook_group.add(hook)
        model_lora, clip_lora = comfy.hooks.load_hook_model_as_lora_for_models(model=model, clip=clip,
                                                                               model_loaded=model_loaded, clip_loaded=clip_loaded,
                                                                               hook=hook,
                                                                               strength_model=strength_model, strength_clip=strength_clip)
        return (model_lora, clip_lora, hook_group)
class RegisterHookModelAsLoraModelOnly:
    """Model-only ("MO") variant of RegisterHookModelAsLora: no CLIP input or output."""
    NodeId = 'RegisterHookModelAsLoraModelOnly'
    NodeName = 'Register Hook Model as LoRA (MO)'
    RETURN_TYPES = ("MODEL", "HOOKS")
    CATEGORY = "advanced/hooks/register"
    FUNCTION = "register_model_as_lora_model_only"

    @classmethod
    def INPUT_TYPES(s):
        checkpoint_choices = folder_paths.get_filename_list("checkpoints")
        required = {
            "model": ("MODEL",),
            "ckpt_name": (checkpoint_choices, ),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
        }
        return {"required": required}

    def register_model_as_lora_model_only(self, model: 'ModelPatcher', ckpt_name: str, strength_model: float):
        # Reuse the full implementation as a plain function with the CLIP side disabled.
        patched_model, _unused_clip, hook_group = RegisterHookModelAsLora.register_model_as_lora(
            self, model=model, clip=None, ckpt_name=ckpt_name,
            strength_model=strength_model, strength_clip=0)
        return (patched_model, hook_group)
#------------------------------------------
###########################################
###########################################
# Schedule Hooks
#------------------------------------------
#------------------------------------------
###########################################
class SetModelHooksOnCond:
    """Attach a HOOKS group directly to a conditioning.

    NOTE(review): unlike the other node classes in this module, this class
    defines no NodeId/NodeName and is not listed in ``node_list``, so it is not
    registered as a node.
    """
    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "advanced/hooks/manual"
    FUNCTION = "attach_hook"

    @classmethod
    def INPUT_TYPES(s):
        required = {"conditioning": ("CONDITIONING",), "hooks": ("HOOKS",)}
        return {"required": required}

    def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
        hooked = comfy.hooks.set_hooks_for_conditioning(conditioning, hooks)
        return (hooked,)
###########################################
# Combine Hooks
#------------------------------------------
class CombineHooks:
    """Merge up to two optional HOOKS inputs into one HookGroup."""
    NodeId = 'CombineHooks2'
    NodeName = 'Combine Hooks [2]'
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        optional = {f"hooks_{letter}": ("HOOKS",) for letter in "AB"}
        return {"required": {}, "optional": optional}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None):
        # Unconnected inputs arrive as None and are passed through as-is;
        # combine_all_hooks is responsible for handling them.
        return (comfy.hooks.HookGroup.combine_all_hooks([hooks_A, hooks_B]),)
class CombineHooksFour:
    """Merge up to four optional HOOKS inputs into one HookGroup."""
    NodeId = 'CombineHooks4'
    NodeName = 'Combine Hooks [4]'
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        optional = {f"hooks_{letter}": ("HOOKS",) for letter in "ABCD"}
        return {"required": {}, "optional": optional}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None):
        # Order A..D is preserved; unconnected (None) inputs are passed through as-is.
        groups = [hooks_A, hooks_B, hooks_C, hooks_D]
        return (comfy.hooks.HookGroup.combine_all_hooks(groups),)
class CombineHooksEight:
    """Merge up to eight optional HOOKS inputs into one HookGroup."""
    NodeId = 'CombineHooks8'
    NodeName = 'Combine Hooks [8]'
    RETURN_TYPES = ("HOOKS",)
    CATEGORY = "advanced/hooks/combine"
    FUNCTION = "combine_hooks"

    @classmethod
    def INPUT_TYPES(s):
        optional = {f"hooks_{letter}": ("HOOKS",) for letter in "ABCDEFGH"}
        return {"required": {}, "optional": optional}

    def combine_hooks(self,
                      hooks_A: comfy.hooks.HookGroup=None,
                      hooks_B: comfy.hooks.HookGroup=None,
                      hooks_C: comfy.hooks.HookGroup=None,
                      hooks_D: comfy.hooks.HookGroup=None,
                      hooks_E: comfy.hooks.HookGroup=None,
                      hooks_F: comfy.hooks.HookGroup=None,
                      hooks_G: comfy.hooks.HookGroup=None,
                      hooks_H: comfy.hooks.HookGroup=None):
        # Order A..H is preserved; unconnected (None) inputs are passed through as-is.
        groups = [hooks_A, hooks_B, hooks_C, hooks_D,
                  hooks_E, hooks_F, hooks_G, hooks_H]
        return (comfy.hooks.HookGroup.combine_all_hooks(groups),)
#------------------------------------------
###########################################
# Node classes exported by this module. Every entry must define NodeId (the
# registry key) and NodeName (the display name) for the mappings below.
node_list = [
    # Register
    RegisterHookLora,
    RegisterHookLoraModelOnly,
    RegisterHookModelAsLora,
    RegisterHookModelAsLoraModelOnly,
    # Combine
    CombineHooks,
    CombineHooksFour,
    CombineHooksEight,
    # Attach
    ConditioningSetProperties,
    PairConditioningSetProperties,
    ConditioningSetDefaultAndCombine,
    PairConditioningSetDefaultAndCombine,
    PairConditioningCombine
]

NODE_CLASS_MAPPINGS = {node_cls.NodeId: node_cls for node_cls in node_list}
NODE_DISPLAY_NAME_MAPPINGS = {node_cls.NodeId: node_cls.NodeName for node_cls in node_list}