feat: add support for HunYuanDit ControlNet (#4245)
* add support for HunYuanDit ControlNet
* fix hunyuandit controlnet
* fix typo in hunyuandit controlnet
* fix typo in hunyuandit controlnet
* fix code format style
* add control_weight support for HunyuanDit Controlnet
* use control_weights in HunyuanDit Controlnet
* fix typo
This commit is contained in:
parent 413322645e
commit 06eb9fb426
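The control_weight option added in this commit scales each of the 19 ControlNet block residuals with an exponential falloff (see HunYuanControlNet.forward in the new file below). A standalone sketch of that schedule; the helper name is illustrative only:

def hunyuan_control_weights(control_weight: float, num_blocks: int = 19):
    # Mirrors the diff's [1.0 * (control_weight ** float(19 - i)) for i in range(19)]:
    # block 0 is scaled by control_weight ** 19 and the last block by control_weight ** 1,
    # so values below 1.0 damp the earliest blocks hardest.
    return [control_weight ** float(num_blocks - i) for i in range(num_blocks)]

print(hunyuan_control_weights(1.0)[:3])  # [1.0, 1.0, 1.0] -- uniform guidance
print(hunyuan_control_weights(0.9)[:3])  # [0.135..., 0.150..., 0.166...] -- damped early blocks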
comfy/controlnet.py
@@ -13,7 +13,8 @@ import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
+import comfy.ldm.hydit.controlnet


 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -382,9 +382,116 @@ def load_controlnet_mmdit(sd):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control

+
+class ControlNetWarperHunyuanDiT(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number):
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return None
+
+        dtype = self.control_model.dtype
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
+        output_dtype = x_noisy.dtype
+        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+            self.cond_hint = None
+            compression_ratio = self.compression_ratio
+            if self.vae is not None:
+                compression_ratio *= self.vae.downscale_ratio
+            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            if self.vae is not None:
+                loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+                self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
+                comfy.model_management.load_models_gpu(loaded_models)
+            if self.latent_format is not None:
+                self.cond_hint = self.latent_format.process_in(self.cond_hint)
+            self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
+        if x_noisy.shape[0] != self.cond_hint.shape[0]:
+            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+
+        def get_tensor(name):
+            if name in cond:
+                if isinstance(cond[name], torch.Tensor):
+                    return cond[name].to(dtype)
+                else:
+                    return cond[name]
+            else:
+                return None
+
+        encoder_hidden_states = get_tensor('c_crossattn')
+        text_embedding_mask = get_tensor('text_embedding_mask')
+        encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
+        text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
+        image_meta_size = get_tensor('image_meta_size')
+        style = get_tensor('style')
+        cos_cis_img = get_tensor('cos_cis_img')
+        sin_cis_img = get_tensor('sin_cis_img')
+
+        timestep = self.model_sampling_current.timestep(t)
+        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+        control = self.control_model(
+            x=x_noisy.to(dtype),
+            t=timestep.float(),
+            condition=self.cond_hint,
+            encoder_hidden_states=encoder_hidden_states,
+            text_embedding_mask=text_embedding_mask,
+            encoder_hidden_states_t5=encoder_hidden_states_t5,
+            text_embedding_mask_t5=text_embedding_mask_t5,
+            image_meta_size=image_meta_size,
+            style=style,
+            cos_cis_img=cos_cis_img,
+            sin_cis_img=sin_cis_img,
+            **self.extra_args
+        )
+        return self.control_merge(control, control_prev, output_dtype)
+
+    def copy(self):
+        c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+
+def load_controlnet_hunyuandit(controlnet_data):
+    supported_inference_dtypes = [torch.float16, torch.float32]
+
+    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+    load_device = comfy.model_management.get_torch_device()
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    if manual_cast_dtype is not None:
+        operations = comfy.ops.manual_cast
+    else:
+        operations = comfy.ops.disable_weight_init
+
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    missing, unexpected = control_model.load_state_dict(controlnet_data)
+
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logging.debug("unexpected controlnet keys: {}".format(unexpected))
+
+    latent_format = comfy.latent_formats.SDXL()
+    control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    return control
+
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+        return load_controlnet_hunyuandit(controlnet_data)
+
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
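Dispatch in load_controlnet is keyed purely off the state dict: a checkpoint containing 'after_proj_list.18.bias' (the bias of the last of the 19 zero-initialized output projections) is treated as a HunYuanDiT ControlNet. A minimal usage sketch; the checkpoint path is hypothetical:

import comfy.controlnet
import comfy.utils

path = "models/controlnet/hunyuandit_controlnet.safetensors"  # hypothetical path
sd = comfy.utils.load_torch_file(path, safe_load=True)
if 'after_proj_list.18.bias' in sd:  # the key the loader sniffs for
    cn = comfy.controlnet.load_controlnet(path)  # yields a ControlNetWarperHunyuanDiT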
comfy/ldm/hydit/controlnet.py (new file)
@@ -0,0 +1,348 @@
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.utils import checkpoint
+
+from comfy.ldm.modules.diffusionmodules.mmdit import (
+    Mlp,
+    TimestepEmbedder,
+    PatchEmbed,
+    RMSNorm,
+)
+from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
+from .poolers import AttentionPool
+
+import comfy.latent_formats
+from .models import HunYuanDiTBlock
+
+from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+
+
+def zero_module(module):
+    for p in module.parameters():
+        nn.init.zeros_(p)
+    return module
+
+
+def calc_rope(x, patch_size, head_size):
+    th = (x.shape[2] + (patch_size // 2)) // patch_size
+    tw = (x.shape[3] + (patch_size // 2)) // patch_size
+    base_size = 512 // 8 // patch_size
+    start, stop = get_fill_resize_and_crop((th, tw), base_size)
+    sub_args = [start, stop, (th, tw)]
+    # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
+    rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    return rope
+
+
+class HunYuanControlNet(nn.Module):
+    """
+    HunYuanControlNet: ControlNet for the HunYuanDiT diffusion model with a Transformer backbone.
+
+    Inherits ModelMixin and ConfigMixin to be compatible with the diffusers StableDiffusionPipeline sampler.
+
+    Inherits PeftAdapterMixin to be compatible with the PEFT training pipeline.
+
+    Parameters
+    ----------
+    args: argparse.Namespace
+        The arguments parsed by argparse.
+    input_size: tuple
+        The size of the input image.
+    patch_size: int
+        The size of the patch.
+    in_channels: int
+        The number of input channels.
+    hidden_size: int
+        The hidden size of the transformer backbone.
+    depth: int
+        The number of transformer blocks.
+    num_heads: int
+        The number of attention heads.
+    mlp_ratio: float
+        The ratio of the hidden size of the MLP in the transformer block.
+    log_fn: callable
+        The logging function.
+    """
+
+    def __init__(
+        self,
+        input_size: tuple = 128,
+        patch_size: int = 2,
+        in_channels: int = 4,
+        hidden_size: int = 1408,
+        depth: int = 40,
+        num_heads: int = 16,
+        mlp_ratio: float = 4.3637,
+        text_states_dim=1024,
+        text_states_dim_t5=2048,
+        text_len=77,
+        text_len_t5=256,
+        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+        size_cond=False,
+        use_style_cond=False,
+        learn_sigma=True,
+        norm="layer",
+        log_fn: callable = print,
+        attn_precision=None,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.log_fn = log_fn
+        self.depth = depth
+        self.learn_sigma = learn_sigma
+        self.in_channels = in_channels
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.text_states_dim = text_states_dim
+        self.text_states_dim_t5 = text_states_dim_t5
+        self.text_len = text_len
+        self.text_len_t5 = text_len_t5
+        self.size_cond = size_cond
+        self.use_style_cond = use_style_cond
+        self.norm = norm
+        self.dtype = dtype
+        self.latent_format = comfy.latent_formats.SDXL
+
+        self.mlp_t5 = nn.Sequential(
+            nn.Linear(
+                self.text_states_dim_t5,
+                self.text_states_dim_t5 * 4,
+                bias=True,
+                dtype=dtype,
+                device=device,
+            ),
+            nn.SiLU(),
+            nn.Linear(
+                self.text_states_dim_t5 * 4,
+                self.text_states_dim,
+                bias=True,
+                dtype=dtype,
+                device=device,
+            ),
+        )
+        # learnable replace
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(
+                self.text_len + self.text_len_t5,
+                self.text_states_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        # Attention pooling
+        pooler_out_dim = 1024
+        self.pooler = AttentionPool(
+            self.text_len_t5,
+            self.text_states_dim_t5,
+            num_heads=8,
+            output_dim=pooler_out_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+
+        # Dimension of the extra input vectors
+        self.extra_in_dim = pooler_out_dim
+
+        if self.size_cond:
+            # Image size and crop size conditions
+            self.extra_in_dim += 6 * 256
+
+        if self.use_style_cond:
+            # Here we use a default learned embedder layer for future extension.
+            self.style_embedder = nn.Embedding(
+                1, hidden_size, dtype=dtype, device=device
+            )
+            self.extra_in_dim += hidden_size
+
+        # Image patch embedding
+        self.x_embedder = PatchEmbed(
+            input_size,
+            patch_size,
+            in_channels,
+            hidden_size,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.t_embedder = TimestepEmbedder(
+            hidden_size, dtype=dtype, device=device, operations=operations
+        )
+        self.extra_embedder = nn.Sequential(
+            operations.Linear(
+                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
+            ),
+            nn.SiLU(),
+            operations.Linear(
+                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
+            ),
+        )
+
+        # Image embedding
+        num_patches = self.x_embedder.num_patches
+
+        # HunYuanDiT blocks
+        self.blocks = nn.ModuleList(
+            [
+                HunYuanDiTBlock(
+                    hidden_size=hidden_size,
+                    c_emb_size=hidden_size,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    text_states_dim=self.text_states_dim,
+                    qk_norm=qk_norm,
+                    norm_type=self.norm,
+                    skip=False,
+                    attn_precision=attn_precision,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+                for _ in range(19)
+            ]
+        )
+
+        # Input zero linear for the first block
+        self.before_proj = zero_module(
+            nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+        )
+
+        # Output zero linear for every block
+        self.after_proj_list = nn.ModuleList(
+            [
+                zero_module(
+                    nn.Linear(
+                        self.hidden_size, self.hidden_size, dtype=dtype, device=device
+                    )
+                )
+                for _ in range(len(self.blocks))
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor = None,
+        condition=None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        control_weight=1.0,
+        transformer_options=None,
+        **kwarg,
+    ):
+        """
+        Forward pass of the encoder.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            (B, D, H, W)
+        t: torch.Tensor
+            (B)
+        encoder_hidden_states: torch.Tensor
+            CLIP text embedding, (B, L_clip, D)
+        text_embedding_mask: torch.Tensor
+            CLIP text embedding mask, (B, L_clip)
+        encoder_hidden_states_t5: torch.Tensor
+            T5 text embedding, (B, L_t5, D)
+        text_embedding_mask_t5: torch.Tensor
+            T5 text embedding mask, (B, L_t5)
+        image_meta_size: torch.Tensor
+            (B, 6)
+        style: torch.Tensor
+            (B)
+        cos_cis_img: torch.Tensor
+        sin_cis_img: torch.Tensor
+        return_dict: bool
+            Whether to return a dictionary.
+        """
+        if condition.shape[0] == 1:
+            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
+
+        text_states = encoder_hidden_states  # 2,77,1024
+        text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
+        text_states_mask = text_embedding_mask.bool()  # 2,77
+        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
+        b_t5, l_t5, c_t5 = text_states_t5.shape
+        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
+
+        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
+
+        text_states[:, -self.text_len :] = torch.where(
+            text_states_mask[:, -self.text_len :].unsqueeze(2),
+            text_states[:, -self.text_len :],
+            padding[: self.text_len],
+        )
+        text_states_t5[:, -self.text_len_t5 :] = torch.where(
+            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
+            text_states_t5[:, -self.text_len_t5 :],
+            padding[self.text_len :],
+        )
+
+        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,333,1024
+
+        # _, _, oh, ow = x.shape
+        # th, tw = oh // self.patch_size, ow // self.patch_size
+
+        # Get image RoPE embedding according to `reso`lution.
+        freqs_cis_img = calc_rope(
+            x, self.patch_size, self.hidden_size // self.num_heads
+        )  # (cos_cis_img, sin_cis_img)
+
+        # ========================= Build time and image embedding =========================
+        t = self.t_embedder(t, dtype=self.dtype)
+        x = self.x_embedder(x)
+
+        # ========================= Concatenate all extra vectors =========================
+        # Build text tokens with pooling
+        extra_vec = self.pooler(encoder_hidden_states_t5)
+
+        # Build image meta size tokens if applicable
+        # if image_meta_size is not None:
+        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)  # [B * 6, 256]
+        #     if image_meta_size.dtype != self.dtype:
+        #         image_meta_size = image_meta_size.half()
+        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
+        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]
+
+        # Build style tokens
+        if style is not None:
+            style_embedding = self.style_embedder(style)
+            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
+
+        # Concatenate all extra vectors
+        c = t + self.extra_embedder(extra_vec)  # [B, D]
+
+        # ========================= Deal with Condition =========================
+        condition = self.x_embedder(condition)
+
+        # ========================= Forward pass through HunYuanDiT blocks =========================
+        controls = []
+        x = x + self.before_proj(condition)  # add condition
+        for layer, block in enumerate(self.blocks):
+            x = block(x, c, text_states, freqs_cis_img)
+            controls.append(self.after_proj_list[layer](x))  # zero linear for output
+
+        control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
+        assert len(control_weights) == len(controls), "control_weights and controls should have the same length"
+        controls = [control * weight for control, weight in zip(controls, control_weights)]
+
+        return {"output": controls}
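As a sanity check on calc_rope's grid arithmetic, a worked example assuming a 1024x1024 image encoded to a 4x128x128 latent:

import torch
from comfy.ldm.hydit.controlnet import calc_rope

latent = torch.zeros(1, 4, 128, 128)
# th = (128 + 2 // 2) // 2 = 64, tw = 64  -> 64x64 patch-token grid
# base_size = 512 // 8 // 2 = 32          -> reference grid for get_fill_resize_and_crop
# head_size = 1408 // 16 = 88             -> rotary dimension per attention head
cos_cis_img, sin_cis_img = calc_rope(latent, patch_size=2, head_size=88)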
comfy/ldm/hydit/models.py
@@ -91,6 +91,8 @@ class HunYuanDiTBlock(nn.Module):
         # Long Skip Connection
         if self.skip_linear is not None:
             cat = torch.cat([x, skip], dim=-1)
+            if cat.dtype != x.dtype:
+                cat = cat.to(x.dtype)
             cat = self.skip_norm(cat)
             x = self.skip_linear(cat)

@@ -362,6 +364,8 @@ class HunYuanDiT(nn.Module):
         c = t + self.extra_embedder(extra_vec)  # [B, D]

         controls = None
+        if control:
+            controls = control.get("output", None)
         # ========================= Forward pass through HunYuanDiT blocks =========================
         skips = []
         for layer, block in enumerate(self.blocks):
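For context on the two lines added to HunYuanDiTBlock: torch.cat type-promotes mixed-precision inputs, so concatenating an fp16 activation with an fp32 skip tensor yields fp32, which would then feed the fp16 skip_linear weights. A small sketch of the mismatch the new cast guards against, assuming fp16 inference with manual casting:

import torch

x = torch.randn(1, 4096, 1408, dtype=torch.float16)     # block activation
skip = torch.randn(1, 4096, 1408, dtype=torch.float32)  # long skip + ControlNet residual
cat = torch.cat([x, skip], dim=-1)
print(cat.dtype)       # torch.float32 -- promoted by torch.cat
cat = cat.to(x.dtype)  # the added guard restores float16 before skip_norm/skip_linear
print(cat.dtype)       # torch.float16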
comfy_extras/nodes_hunyuan.py
@@ -19,6 +19,58 @@ class CLIPTextEncodeHunyuanDiT:
         cond = output.pop("cond")
         return ([[cond, output]], )
+
+
+class ControlNetApplyAdvancedHunYuan:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "image": ("IMAGE", ),
+                             "vae": ("VAE", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
+    RETURN_NAMES = ("positive", "negative")
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning/controlnet"
+
+    def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None):
+        if strength == 0:
+            return (positive, negative)
+
+        control_hint = image.movedim(-1, 1)
+        cnets = {}
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+
+                prev_cnet = d.get('control', None)
+                if prev_cnet in cnets:
+                    c_net = cnets[prev_cnet]
+                else:
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
+                    c_net.set_extra_arg('control_weight', control_weight)
+
+                    c_net.set_previous_controlnet(prev_cnet)
+                    cnets[prev_cnet] = c_net
+
+                d['control'] = c_net
+                d['control_apply_to_uncond'] = False
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return (out[0], out[1])
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan,
 }
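End to end, control_weight flows from this node into the model: set_extra_arg stores it on the ControlNet copy, ControlNetWarperHunyuanDiT.get_control forwards it through **self.extra_args, and HunYuanControlNet.forward turns it into the per-block schedule. A hedged sketch of the same wiring done directly in Python; control_net, control_hint, vae, and cond_entry are assumed to be already prepared:

c_net = control_net.copy().set_cond_hint(control_hint, 1.0, (0.0, 1.0), vae)
c_net.set_extra_arg('control_weight', 0.85)   # reaches HunYuanControlNet.forward via **extra_args
cond_entry['control'] = c_net                 # one conditioning dict, as in the node above
cond_entry['control_apply_to_uncond'] = False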