From 06eb9fb426706550fe46aa4e36e2abcba9af241d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A5=E6=96=B0=E7=92=90?= <35400185+CrazyBoyM@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:59:24 +0800 Subject: [PATCH] feat: add support for HunYuanDit ControlNet (#4245) * add support for HunYuanDit ControlNet * fix hunyuandit controlnet * fix typo in hunyuandit controlnet * fix typo in hunyuandit controlnet * fix code format style * add control_weight support for HunyuanDit Controlnet * use control_weights in HunyuanDit Controlnet * fix typo --- comfy/controlnet.py | 109 ++++++++++- comfy/ldm/hydit/controlnet.py | 348 ++++++++++++++++++++++++++++++++++ comfy/ldm/hydit/models.py | 4 + comfy_extras/nodes_hunyuan.py | 52 +++++ 4 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/hydit/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 97e4f4d0..c11a759e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -13,7 +13,7 @@ import comfy.cldm.cldm import comfy.t2i_adapter.adapter import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit - +import comfy.ldm.hydit.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -382,9 +382,116 @@ def load_controlnet_mmdit(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control +class ControlNetWarperHunyuanDiT(ControlNet): + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return None + + dtype = self.control_model.dtype + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + + output_dtype = x_noisy.dtype + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + compression_ratio = self.compression_ratio + if self.vae is not None: + compression_ratio *= self.vae.downscale_ratio + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + if self.vae is not None: + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) + comfy.model_management.load_models_gpu(loaded_models) + if self.latent_format is not None: + self.cond_hint = self.latent_format.process_in(self.cond_hint) + self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + + def get_tensor(name): + if name in cond: + if isinstance(cond[name], torch.Tensor): + return cond[name].to(dtype) + else: + return cond[name] + else: + return None + + encoder_hidden_states = get_tensor('c_crossattn') + text_embedding_mask = get_tensor('text_embedding_mask') + encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5') + text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5') + 
image_meta_size = get_tensor('image_meta_size') + style = get_tensor('style') + cos_cis_img = get_tensor('cos_cis_img') + sin_cis_img = get_tensor('sin_cis_img') + + timestep = self.model_sampling_current.timestep(t) + x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) + + control = self.control_model( + x=x_noisy.to(dtype), + t=timestep.float(), + condition=self.cond_hint, + encoder_hidden_states=encoder_hidden_states, + text_embedding_mask=text_embedding_mask, + encoder_hidden_states_t5=encoder_hidden_states_t5, + text_embedding_mask_t5=text_embedding_mask_t5, + image_meta_size=image_meta_size, + style=style, + cos_cis_img=cos_cis_img, + sin_cis_img=sin_cis_img, + **self.extra_args + ) + return self.control_merge(control, control_prev, output_dtype) + + def copy(self): + c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + +def load_controlnet_hunyuandit(controlnet_data): + + supported_inference_dtypes = [torch.float16, torch.float32] + + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + load_device = comfy.model_management.get_torch_device() + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init + + control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) + missing, unexpected = control_model.load_state_dict(controlnet_data) + + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) + + latent_format = comfy.latent_formats.SDXL() + control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + return control def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT + return load_controlnet_hunyuandit(controlnet_data) + if "lora_controlnet" in controlnet_data: return ControlLora(controlnet_data) diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py new file mode 100644 index 00000000..0d3f7966 --- /dev/null +++ b/comfy/ldm/hydit/controlnet.py @@ -0,0 +1,348 @@ +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.utils import checkpoint + +from comfy.ldm.modules.diffusionmodules.mmdit import ( + Mlp, + TimestepEmbedder, + PatchEmbed, + RMSNorm, +) +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from .poolers import AttentionPool + +import comfy.latent_formats +from .models import HunYuanDiTBlock + +from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def calc_rope(x, patch_size, head_size): + th = (x.shape[2] + (patch_size // 2)) // patch_size + tw = (x.shape[3] + (patch_size // 2)) // patch_size + base_size = 512 // 8 // patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) 
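+    # (th, tw) is the token grid of the current latent: division of the latent
+    # height/width by the patch size, rounded up, so odd resolutions still cover
+    # the full image.  base_size = 512 // 8 // patch_size is the reference grid at
+    # the 512-pixel base resolution (8x VAE downscale, then patchification), and
+    # get_fill_resize_and_crop((th, tw), base_size) returns the start/stop
+    # coordinates used to sample a 2D rotary embedding for the current grid from
+    # that reference layout.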
+ sub_args = [start, stop, (th, tw)] + # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] + rope = get_2d_rotary_pos_embed(head_size, *sub_args) + return rope + + +class HunYuanControlNet(nn.Module): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline. + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + input_size: tuple + The size of the input image. + patch_size: int + The size of the patch. + in_channels: int + The number of input channels. + hidden_size: int + The hidden size of the transformer backbone. + depth: int + The number of transformer blocks. + num_heads: int + The number of attention heads. + mlp_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + log_fn: callable + The logging function. + """ + + def __init__( + self, + input_size: tuple = 128, + patch_size: int = 2, + in_channels: int = 4, + hidden_size: int = 1408, + depth: int = 40, + num_heads: int = 16, + mlp_ratio: float = 4.3637, + text_states_dim=1024, + text_states_dim_t5=2048, + text_len=77, + text_len_t5=256, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + size_cond=False, + use_style_cond=False, + learn_sigma=True, + norm="layer", + log_fn: callable = print, + attn_precision=None, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.log_fn = log_fn + self.depth = depth + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.text_states_dim = text_states_dim + self.text_states_dim_t5 = text_states_dim_t5 + self.text_len = text_len + self.text_len_t5 = text_len_t5 + self.size_cond = size_cond + self.use_style_cond = use_style_cond + self.norm = norm + self.dtype = dtype + self.latent_format = comfy.latent_formats.SDXL + + self.mlp_t5 = nn.Sequential( + nn.Linear( + self.text_states_dim_t5, + self.text_states_dim_t5 * 4, + bias=True, + dtype=dtype, + device=device, + ), + nn.SiLU(), + nn.Linear( + self.text_states_dim_t5 * 4, + self.text_states_dim, + bias=True, + dtype=dtype, + device=device, + ), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn( + self.text_len + self.text_len_t5, + self.text_states_dim, + dtype=dtype, + device=device, + ) + ) + + # Attention pooling + pooler_out_dim = 1024 + self.pooler = AttentionPool( + self.text_len_t5, + self.text_states_dim_t5, + num_heads=8, + output_dim=pooler_out_dim, + dtype=dtype, + device=device, + operations=operations, + ) + + # Dimension of the extra input vectors + self.extra_in_dim = pooler_out_dim + + if self.size_cond: + # Image size and crop size conditions + self.extra_in_dim += 6 * 256 + + if self.use_style_cond: + # Here we use a default learned embedder layer for future extension. 
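+            # use_style_cond defaults to False and load_controlnet_hunyuandit()
+            # builds the model with the defaults, so this branch stays inactive for
+            # checkpoints loaded through ComfyUI; it mirrors the base HunYuanDiT
+            # definition.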
+ self.style_embedder = nn.Embedding( + 1, hidden_size, dtype=dtype, device=device + ) + self.extra_in_dim += hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + hidden_size, + dtype=dtype, + device=device, + operations=operations, + ) + self.t_embedder = TimestepEmbedder( + hidden_size, dtype=dtype, device=device, operations=operations + ) + self.extra_embedder = nn.Sequential( + operations.Linear( + self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device + ), + nn.SiLU(), + operations.Linear( + hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device + ), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunYuanDiTBlock( + hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + qk_norm=qk_norm, + norm_type=self.norm, + skip=False, + attn_precision=attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(19) + ] + ) + + # Input zero linear for the first block + self.before_proj = zero_module( + nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + ) + + # Output zero linear for the every block + self.after_proj_list = nn.ModuleList( + [ + zero_module( + nn.Linear( + self.hidden_size, self.hidden_size, dtype=dtype, device=device + ) + ) + for _ in range(len(self.blocks)) + ] + ) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor = None, + condition=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + control_weight=1.0, + transformer_options=None, + **kwarg, + ): + """ + Forward pass of the encoder. + + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. 
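+        condition: torch.Tensor
+            Control hint in the same latent layout as `x`, (B, D, H, W); it is
+            patch-embedded and added to the hidden states through the
+            zero-initialized `before_proj` linear.
+        control_weight: float
+            Scaling applied to the per-block outputs: block i is multiplied by
+            control_weight ** (19 - i).
+
+        Returns
+        -------
+        dict
+            {"output": controls}, where `controls` holds one scaled residual
+            tensor per transformer block and is consumed by HunYuanDiT via
+            control.get("output").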
+ """ + if condition.shape[0] == 1: + condition = torch.repeat_interleave(condition, x.shape[0], dim=0) + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1) + + padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states) + + text_states[:, -self.text_len :] = torch.where( + text_states_mask[:, -self.text_len :].unsqueeze(2), + text_states[:, -self.text_len :], + padding[: self.text_len], + ) + text_states_t5[:, -self.text_len_t5 :] = torch.where( + text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2), + text_states_t5[:, -self.text_len_t5 :], + padding[self.text_len :], + ) + + text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024 + + # _, _, oh, ow = x.shape + # th, tw = oh // self.patch_size, ow // self.patch_size + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = calc_rope( + x, self.patch_size, self.hidden_size // self.num_heads + ) # (cos_cis_img, sin_cis_img) + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t, dtype=self.dtype) + x = self.x_embedder(x) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens if applicable + # if image_meta_size is not None: + # image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + # if image_meta_size.dtype != self.dtype: + # image_meta_size = image_meta_size.half() + # image_meta_size = image_meta_size.view(-1, 6 * 256) + # extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + if style is not None: + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Deal with Condition ========================= + condition = self.x_embedder(condition) + + # ========================= Forward pass through HunYuanDiT blocks ========================= + controls = [] + x = x + self.before_proj(condition) # add condition + for layer, block in enumerate(self.blocks): + x = block(x, c, text_states, freqs_cis_img) + controls.append(self.after_proj_list[layer](x)) # zero linear for output + + control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)] + assert len(control_weights) == len( + controls + ), "control_weights and controls should have the same length" + controls = [ + control * weight for control, weight in zip(controls, control_weights) + ] + + return {"output": controls} diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index c70dbf92..9a1f3733 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -91,6 +91,8 @@ class HunYuanDiTBlock(nn.Module): # Long Skip Connection if self.skip_linear is not None: cat = torch.cat([x, skip], dim=-1) + if cat.dtype != x.dtype: + cat = cat.to(x.dtype) cat = self.skip_norm(cat) x = self.skip_linear(cat) @@ -362,6 +364,8 @@ class HunYuanDiT(nn.Module): c = t + self.extra_embedder(extra_vec) # [B, D] controls = None + if control: + controls = 
control.get("output", None) # ========================= Forward pass through HunYuanDiT blocks ========================= skips = [] for layer, block in enumerate(self.blocks): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index a3ac8cb0..4f2ccfe9 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -19,6 +19,58 @@ class CLIPTextEncodeHunyuanDiT: cond = output.pop("cond") return ([[cond, output]], ) + +class ControlNetApplyAdvancedHunYuan: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "vae": ("VAE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning/controlnet" + + def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None): + if strength == 0: + return (positive, negative) + + control_hint = image.movedim(-1,1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) + c_net.set_extra_arg('control_weight', control_weight) + + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1]) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan, }
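
For reference, a minimal usage sketch of the pieces this patch adds, driven outside the graph UI. The loader call and the node signature are taken from the diff above; the checkpoint path and the upstream positive, negative, image, and vae objects are illustrative placeholders, and importing the node module directly is an assumption (inside ComfyUI the class is normally reached through NODE_CLASS_MAPPINGS):

import comfy.controlnet
from comfy_extras.nodes_hunyuan import ControlNetApplyAdvancedHunYuan  # assumes the module is importable

# load_controlnet() now detects HunYuanDiT ControlNets by the
# 'after_proj_list.18.bias' key and routes them to load_controlnet_hunyuandit().
control_net = comfy.controlnet.load_controlnet("hunyuandit_controlnet.safetensors")  # illustrative path

node = ControlNetApplyAdvancedHunYuan()
positive, negative = node.apply_controlnet(
    positive, negative,      # CONDITIONING, e.g. from CLIPTextEncodeHunyuanDiT
    control_net,
    image,                   # IMAGE hint (e.g. a canny or depth map)
    strength=1.0,
    control_weight=1.0,      # forwarded to the model via set_extra_arg('control_weight', ...)
    start_percent=0.0,
    end_percent=1.0,
    vae=vae,                 # stored by set_cond_hint and used to encode the hint into latent space
)

The two returned CONDITIONING lists carry the wrapped ControlNet in their 'control' entry, so they can be passed straight to a sampler node just like the output of the generic ControlNetApplyAdvanced node.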