diff --git a/comfy/cldm/mmdit.py b/comfy/cldm/mmdit.py
index 025c2fb5..54a58ab8 100644
--- a/comfy/cldm/mmdit.py
+++ b/comfy/cldm/mmdit.py
@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
     def __init__(
         self,
         num_blocks = None,
+        control_latent_channels = None,
         dtype = None,
         device = None,
         operations = None,
@@ -17,10 +18,14 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
         for _ in range(len(self.joint_blocks)):
             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
 
+        # Use the model's own input channel count unless the caller overrides it.
+        if control_latent_channels is None:
+            control_latent_channels = self.in_channels
+
         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
             None,
             self.patch_size,
-            self.in_channels,
+            control_latent_channels,
             self.hidden_size,
             bias=True,
             strict_img_size=False,
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 86089196..1ea00ecc 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -79,13 +79,22 @@ class ControlBase:
         self.previous_controlnet = None
         self.extra_conds = []
         self.strength_type = StrengthType.CONSTANT
+        self.concat_mask = False
+        self.extra_concat_orig = []
+        self.extra_concat = None
 
-    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=None):
+        """Store the control hint and its settings on this object and return self."""
         self.cond_hint_original = cond_hint
         self.strength = strength
         self.timestep_percent_range = timestep_percent_range
         if self.latent_format is not None:
             self.vae = vae
+        self.extra_concat_orig = [] if extra_concat is None else list(extra_concat)
+        # A mask-concat controlnet always needs something to concatenate: when the caller
+        # supplied nothing, fall back to a single-element tensor holding 1.0.
+        if self.concat_mask and not self.extra_concat_orig:
+            self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
         return self
 
     def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +109,9 @@ class ControlBase:
     def cleanup(self):
         if self.previous_controlnet is not None:
             self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
+
+        self.cond_hint = None
+        self.extra_concat = None
         self.timestep_range = None
 
     def get_models(self):
@@ -123,6 +132,8 @@ class ControlBase:
         c.vae = self.vae
         c.extra_conds = self.extra_conds.copy()
         c.strength_type = self.strength_type
+        c.concat_mask = self.concat_mask
+        c.extra_concat_orig = self.extra_concat_orig.copy()
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -175,7 +186,7 @@ class ControlBase:
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -189,6 +200,7 @@ class ControlNet(ControlBase):
         self.latent_format = latent_format
         self.extra_conds += extra_conds
         self.strength_type = strength_type
+        self.concat_mask = concat_mask
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -220,6 +232,14 @@ class ControlNet(ControlBase):
                 comfy.model_management.load_models_gpu(loaded_models)
             if self.latent_format is not None:
                 self.cond_hint = self.latent_format.process_in(self.cond_hint)
+            if self.extra_concat_orig:
+                # Resize each extra tensor to the hint's spatial size and batch size, then append them along dim 1.
+                to_concat = []
+                for c in self.extra_concat_orig:
+                    c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+                    to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
+                self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
+
             self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -410,12 +430,17 @@ def load_controlnet_mmdit(sd):
     for k in sd:
         new_sd[k] = sd[k]
 
-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    concat_mask = False
+    control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
+    if control_latent_channels == 17: #inpaint controlnet
+        concat_mask = True
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
 
 
@@ -450,13 +475,16 @@ def load_controlnet_flux_instantx(sd):
         num_union_modes = new_sd[union_cnet].shape[0]
 
     control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
+    concat_mask = False
+    if control_latent_channels == 17:
+        concat_mask = True
 
     control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.Flux()
     extra_conds = ['y', 'guidance']
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
 def convert_mistoline(sd):
diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py
index 0773c8a5..7cf6ce60 100644
--- a/comfy_extras/nodes_controlnet.py
+++ b/comfy_extras/nodes_controlnet.py
@@ -1,4 +1,6 @@
 from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
+import nodes
+import comfy.utils
 
 class SetUnionControlNetType:
     @classmethod
@@ -22,6 +24,41 @@ class SetUnionControlNetType:
 
         return (control_net,)
 
+class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
+    """Apply an inpainting controlnet using an image plus a mask.
+
+    The masked image becomes the control hint and the inverted mask is
+    forwarded as an extra tensor to concatenate to that hint.
+    """
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "vae": ("VAE", ),
+                             "image": ("IMAGE", ),
+                             "mask": ("MASK", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+
+    FUNCTION = "apply_inpaint_controlnet"
+
+    CATEGORY = "conditioning/controlnet"
+
+    def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
+        # Invert the mask and give it an explicit channel dimension of size 1.
+        mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
+        # Resize the inverted mask to the image size and round it (assumes image is batch, height, width, channels - TODO confirm).
+        mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
+        # Multiply the image by the rounded inverted mask, repeated across the image's last dimension.
+        image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
+        return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=[mask])
+
+
+
 NODE_CLASS_MAPPINGS = {
     "SetUnionControlNetType": SetUnionControlNetType,
+    "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
 }
diff --git a/nodes.py b/nodes.py
index 0d3749d4..84524b6b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -824,7 +824,7 @@ class ControlNetApplyAdvanced:
 
     CATEGORY = "conditioning/controlnet"
 
-    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=None):
         if strength == 0:
             return (positive, negative)
 
@@ -841,7 +841,7 @@ class ControlNetApplyAdvanced:
                 if prev_cnet in cnets:
                     c_net = cnets[prev_cnet]
                 else:
-                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
                     c_net.set_previous_controlnet(prev_cnet)
                     cnets[prev_cnet] = c_net
 