Add SetUnionControlNetType to set the type of the union controlnet model.

This commit is contained in:
comfyanonymous 2024-07-16 17:01:40 -04:00
parent 821f93872e
commit 8270c62530
3 changed files with 45 additions and 1 deletions

View File

@ -45,6 +45,7 @@ class ControlBase:
self.timestep_range = None self.timestep_range = None
self.compression_ratio = 8 self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact' self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
@ -90,6 +91,7 @@ class ControlBase:
c.compression_ratio = self.compression_ratio c.compression_ratio = self.compression_ratio
c.upscale_algorithm = self.upscale_algorithm c.upscale_algorithm = self.upscale_algorithm
c.latent_format = self.latent_format c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae c.vae = self.vae
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
@ -135,6 +137,10 @@ class ControlBase:
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
return out return out
def set_extra_arg(self, argument, value=None):
self.extra_args[argument] = value
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device) super().__init__(device)
@ -191,7 +197,7 @@ class ControlNet(ControlBase):
timestep = self.model_sampling_current.timestep(t) timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
return self.control_merge(control, control_prev, output_dtype) return self.control_merge(control, control_prev, output_dtype)
def copy(self): def copy(self):

View File

@ -0,0 +1,37 @@
UNION_CONTROLNET_TYPES = {"auto": -1,
"openpose": 0,
"depth": 1,
"hed/pidi/scribble/ted": 2,
"canny/lineart/anime_lineart/mlsd": 3,
"normal": 4,
"segment": 5,
"tile": 6,
"repaint": 7,
}
class SetUnionControlNetType:
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
"type": (list(UNION_CONTROLNET_TYPES.keys()),)
}}
CATEGORY = "conditioning"
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "set_controlnet_type"
def set_controlnet_type(self, control_net, type):
control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES[type]
if type_number >= 0:
control_net.set_extra_arg("control_type", [type_number])
else:
control_net.set_extra_arg("control_type", [])
return (control_net,)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
}

View File

@ -2036,6 +2036,7 @@ def init_builtin_extra_nodes():
"nodes_audio.py", "nodes_audio.py",
"nodes_sd3.py", "nodes_sd3.py",
"nodes_gits.py", "nodes_gits.py",
"nodes_controlnet.py",
] ]
import_failed = [] import_failed = []