diff --git a/comfy/controlnet.py b/comfy/controlnet.py index b8e27c71..12e5f16c 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -45,6 +45,7 @@ class ControlBase: self.timestep_range = None self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' + self.extra_args = {} if device is None: device = comfy.model_management.get_torch_device() @@ -90,6 +91,7 @@ class ControlBase: c.compression_ratio = self.compression_ratio c.upscale_algorithm = self.upscale_algorithm c.latent_format = self.latent_format + c.extra_args = self.extra_args.copy() c.vae = self.vae def inference_memory_requirements(self, dtype): @@ -135,6 +137,10 @@ class ControlBase: o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue return out + def set_extra_arg(self, argument, value=None): + self.extra_args[argument] = value + + class ControlNet(ControlBase): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) @@ -191,7 +197,7 @@ class ControlNet(ControlBase): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args) return self.control_merge(control, control_prev, output_dtype) def copy(self): diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py new file mode 100644 index 00000000..e550436b --- /dev/null +++ b/comfy_extras/nodes_controlnet.py @@ -0,0 +1,37 @@ + +UNION_CONTROLNET_TYPES = {"auto": -1, + "openpose": 0, + "depth": 1, + "hed/pidi/scribble/ted": 2, + "canny/lineart/anime_lineart/mlsd": 3, + "normal": 4, + "segment": 5, + "tile": 6, + "repaint": 7, + } + +class SetUnionControlNetType: + @classmethod + def INPUT_TYPES(s): + return {"required": {"control_net": ("CONTROL_NET", ), + "type": (list(UNION_CONTROLNET_TYPES.keys()),) + }} + + CATEGORY = "conditioning" + RETURN_TYPES = ("CONTROL_NET",) + + FUNCTION = "set_controlnet_type" + + def set_controlnet_type(self, control_net, type): + control_net = control_net.copy() + type_number = UNION_CONTROLNET_TYPES[type] + if type_number >= 0: + control_net.set_extra_arg("control_type", [type_number]) + else: + control_net.set_extra_arg("control_type", []) + + return (control_net,) + +NODE_CLASS_MAPPINGS = { + "SetUnionControlNetType": SetUnionControlNetType, +} diff --git a/nodes.py b/nodes.py index 998b316b..89a2f21d 100644 --- a/nodes.py +++ b/nodes.py @@ -2036,6 +2036,7 @@ def init_builtin_extra_nodes(): "nodes_audio.py", "nodes_sd3.py", "nodes_gits.py", + "nodes_controlnet.py", ] import_failed = []