Implement model part of flux union controlnet.

This commit is contained in:
comfyanonymous 2024-08-29 18:41:22 -04:00
parent ea3f39bd69
commit 10a79e9898
2 changed files with 19 additions and 3 deletions

View File

@ -444,7 +444,12 @@ def load_controlnet_flux_instantx(sd):
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()

View File

@ -14,7 +14,7 @@ import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
@ -29,6 +29,11 @@ class ControlNetFlux(Flux):
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@ -61,6 +66,7 @@ class ControlNetFlux(Flux):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@ -79,6 +85,11 @@ class ControlNetFlux(Flux):
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@ -137,4 +148,4 @@ class ControlNetFlux(Flux):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))