feat: add support for HunYuanDit ControlNet (#4245)

* add support for HunYuanDit ControlNet

* fix hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix code format style

* add control_weight support for HunyuanDit Controlnet

* use control_weights in HunyuanDit Controlnet

* fix typo
This commit is contained in:
来新璐 2024-08-09 14:59:24 +08:00 committed by GitHub
parent 413322645e
commit 06eb9fb426
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 512 additions and 1 deletions

View File

@ -13,7 +13,7 @@ import comfy.cldm.cldm
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0] current_batch_size = tensor.shape[0]
@ -382,9 +382,116 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
class ControlNetWarperHunyuanDiT(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
def get_tensor(name):
if name in cond:
if isinstance(cond[name], torch.Tensor):
return cond[name].to(dtype)
else:
return cond[name]
else:
return None
encoder_hidden_states = get_tensor('c_crossattn')
text_embedding_mask = get_tensor('text_embedding_mask')
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
image_meta_size = get_tensor('image_meta_size')
style = get_tensor('style')
cos_cis_img = get_tensor('cos_cis_img')
sin_cis_img = get_tensor('sin_cis_img')
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(
x=x_noisy.to(dtype),
t=timestep.float(),
condition=self.cond_hint,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
**self.extra_args
)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_hunyuandit(controlnet_data):
supported_inference_dtypes = [torch.float16, torch.float32]
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
missing, unexpected = control_model.load_state_dict(controlnet_data)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
latent_format = comfy.latent_formats.SDXL()
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None): def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
if "lora_controlnet" in controlnet_data: if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data) return ControlLora(controlnet_data)

View File

@ -0,0 +1,348 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
def __init__(
self,
input_size: tuple = 128,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1408,
depth: int = 40,
num_heads: int = 16,
mlp_ratio: float = 4.3637,
text_states_dim=1024,
text_states_dim_t5=2048,
text_len=77,
text_len_t5=256,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
size_cond=False,
use_style_cond=False,
learn_sigma=True,
norm="layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
self.latent_format = comfy.latent_formats.SDXL
self.mlp_t5 = nn.Sequential(
nn.Linear(
self.text_states_dim_t5,
self.text_states_dim_t5 * 4,
bias=True,
dtype=dtype,
device=device,
),
nn.SiLU(),
nn.Linear(
self.text_states_dim_t5 * 4,
self.text_states_dim,
bias=True,
dtype=dtype,
device=device,
),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=dtype,
device=device,
)
)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
dtype=dtype,
device=device,
operations=operations,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(
1, hidden_size, dtype=dtype, device=device
)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
hidden_size,
dtype=dtype,
device=device,
operations=operations,
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations
)
self.extra_embedder = nn.Sequential(
operations.Linear(
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
),
nn.SiLU(),
operations.Linear(
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=False,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(19)
]
)
# Input zero linear for the first block
self.before_proj = zero_module(
nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
)
# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
zero_module(
nn.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
)
for _ in range(len(self.blocks))
]
)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor = None,
condition=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
control_weight=1.0,
transformer_options=None,
**kwarg,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:, -self.text_len :] = torch.where(
text_states_mask[:, -self.text_len :].unsqueeze(2),
text_states[:, -self.text_len :],
padding[: self.text_len],
)
text_states_t5[:, -self.text_len_t5 :] = torch.where(
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
text_states_t5[:, -self.text_len_t5 :],
padding[self.text_len :],
)
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,2051024
# _, _, oh, ow = x.shape
# th, tw = oh // self.patch_size, ow // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(
x, self.patch_size, self.hidden_size // self.num_heads
) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=self.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
# if image_meta_size is not None:
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
# if image_meta_size.dtype != self.dtype:
# image_meta_size = image_meta_size.half()
# image_meta_size = image_meta_size.view(-1, 6 * 256)
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Deal with Condition =========================
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
assert len(control_weights) == len(
controls
), "control_weights and controls should have the same length"
controls = [
control * weight for control, weight in zip(controls, control_weights)
]
return {"output": controls}

View File

@ -91,6 +91,8 @@ class HunYuanDiTBlock(nn.Module):
# Long Skip Connection # Long Skip Connection
if self.skip_linear is not None: if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1) cat = torch.cat([x, skip], dim=-1)
if cat.dtype != x.dtype:
cat = cat.to(x.dtype)
cat = self.skip_norm(cat) cat = self.skip_norm(cat)
x = self.skip_linear(cat) x = self.skip_linear(cat)
@ -362,6 +364,8 @@ class HunYuanDiT(nn.Module):
c = t + self.extra_embedder(extra_vec) # [B, D] c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None controls = None
if control:
controls = control.get("output", None)
# ========================= Forward pass through HunYuanDiT blocks ========================= # ========================= Forward pass through HunYuanDiT blocks =========================
skips = [] skips = []
for layer, block in enumerate(self.blocks): for layer, block in enumerate(self.blocks):

View File

@ -19,6 +19,58 @@ class CLIPTextEncodeHunyuanDiT:
cond = output.pop("cond") cond = output.pop("cond")
return ([[cond, output]], ) return ([[cond, output]], )
class ControlNetApplyAdvancedHunYuan:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"vae": ("VAE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None):
if strength == 0:
return (positive, negative)
control_hint = image.movedim(-1,1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net.set_extra_arg('control_weight', control_weight)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1])
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan,
} }