Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
This commit is contained in:
parent
369a6dd2c4
commit
f48e390032
|
@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
|
control_latent_channels = None,
|
||||||
dtype = None,
|
dtype = None,
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
|
@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
for _ in range(len(self.joint_blocks)):
|
for _ in range(len(self.joint_blocks)):
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
None,
|
None,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
self.in_channels,
|
control_latent_channels,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
strict_img_size=False,
|
strict_img_size=False,
|
||||||
|
|
|
@ -79,13 +79,19 @@ class ControlBase:
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
|
self.concat_mask = False
|
||||||
|
self.extra_concat_orig = []
|
||||||
|
self.extra_concat = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
self.extra_concat_orig = extra_concat.copy()
|
||||||
|
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||||
|
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
@ -100,9 +106,9 @@ class ControlBase:
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
self.cond_hint = None
|
||||||
self.cond_hint = None
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
|
@ -123,6 +129,8 @@ class ControlBase:
|
||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
c.extra_conds = self.extra_conds.copy()
|
c.extra_conds = self.extra_conds.copy()
|
||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
|
c.concat_mask = self.concat_mask
|
||||||
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
|
@ -175,7 +183,7 @@ class ControlBase:
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
|
@ -189,6 +197,7 @@ class ControlNet(ControlBase):
|
||||||
self.latent_format = latent_format
|
self.latent_format = latent_format
|
||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
|
self.concat_mask = concat_mask
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
|
@ -220,6 +229,13 @@ class ControlNet(ControlBase):
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
|
if len(self.extra_concat_orig) > 0:
|
||||||
|
to_concat = []
|
||||||
|
for c in self.extra_concat_orig:
|
||||||
|
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||||
|
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||||
|
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||||
|
|
||||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
@ -410,12 +426,17 @@ def load_controlnet_mmdit(sd):
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
concat_mask = False
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||||
|
if control_latent_channels == 17: #inpaint controlnet
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
@ -450,13 +471,16 @@ def load_controlnet_flux_instantx(sd):
|
||||||
num_union_modes = new_sd[union_cnet].shape[0]
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.Flux()
|
latent_format = comfy.latent_formats.Flux()
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -22,6 +24,33 @@ class SetUnionControlNetType:
|
||||||
|
|
||||||
return (control_net,)
|
return (control_net,)
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
FUNCTION = "apply_inpaint_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
|
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||||
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
|
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||||
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
|
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=[mask])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||||
}
|
}
|
||||||
|
|
4
nodes.py
4
nodes.py
|
@ -824,7 +824,7 @@ class ControlNetApplyAdvanced:
|
||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
|
@ -841,7 +841,7 @@ class ControlNetApplyAdvanced:
|
||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue