Cleanup HunyuanDit controlnets.

Use the: ControlNetApply SD3 and HunyuanDiT node.
This commit is contained in:
comfyanonymous 2024-08-09 02:35:19 -04:00
parent 06eb9fb426
commit a475ec2300
4 changed files with 60 additions and 194 deletions

View File

@ -1,4 +1,24 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import torch
from enum import Enum
import math import math
import os import os
import logging import logging
@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else: else:
return torch.cat([tensor] * batched_number, dim=0) return torch.cat([tensor] * batched_number, dim=0)
class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlBase: class ControlBase:
def __init__(self, device=None): def __init__(self, device=None):
self.cond_hint_original = None self.cond_hint_original = None
@ -51,6 +75,8 @@ class ControlBase:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -93,6 +119,8 @@ class ControlBase:
c.latent_format = self.latent_format c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy() c.extra_args = self.extra_args.copy()
c.vae = self.vae c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -113,7 +141,10 @@ class ControlBase:
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x) applied_to.add(x)
x *= self.strength if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype: if x.dtype != output_dtype:
x = x.to(output_dtype) x = x.to(output_dtype)
@ -142,7 +173,7 @@ class ControlBase:
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
@ -154,6 +185,8 @@ class ControlNet(ControlBase):
self.model_sampling_current = None self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -192,7 +225,7 @@ class ControlNet(ControlBase):
context = cond.get('crossattn_controlnet', cond['c_crossattn']) context = cond.get('crossattn_controlnet', cond['c_crossattn'])
extra = self.extra_args.copy() extra = self.extra_args.copy()
for c in ["y", "guidance"]: #TODO for c in self.extra_conds:
temp = cond.get(c, None) temp = cond.get(c, None)
if temp is not None: if temp is not None:
extra[c] = temp.to(dtype) extra[c] = temp.to(dtype)
@ -382,109 +415,15 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
class ControlNetWarperHunyuanDiT(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
def get_tensor(name):
if name in cond:
if isinstance(cond[name], torch.Tensor):
return cond[name].to(dtype)
else:
return cond[name]
else:
return None
encoder_hidden_states = get_tensor('c_crossattn')
text_embedding_mask = get_tensor('text_embedding_mask')
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
image_meta_size = get_tensor('image_meta_size')
style = get_tensor('style')
cos_cis_img = get_tensor('cos_cis_img')
sin_cis_img = get_tensor('sin_cis_img')
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(
x=x_noisy.to(dtype),
t=timestep.float(),
condition=self.cond_hint,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
**self.extra_args
)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_hunyuandit(controlnet_data): def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
supported_inference_dtypes = [torch.float16, torch.float32]
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
missing, unexpected = control_model.load_state_dict(controlnet_data) control_model = controlnet_load_state_dict(control_model, controlnet_data)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
latent_format = comfy.latent_formats.SDXL() latent_format = comfy.latent_formats.SDXL()
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP)
return control return control
def load_controlnet(ckpt_path, model=None): def load_controlnet(ckpt_path, model=None):

View File

@ -16,28 +16,11 @@ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool from .poolers import AttentionPool
import comfy.latent_formats import comfy.latent_formats
from .models import HunYuanDiTBlock from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
class HunYuanControlNet(nn.Module): class HunYuanControlNet(nn.Module):
""" """
HunYuanDiT: Diffusion model with a Transformer backbone. HunYuanDiT: Diffusion model with a Transformer backbone.
@ -213,35 +196,32 @@ class HunYuanControlNet(nn.Module):
) )
# Input zero linear for the first block # Input zero linear for the first block
self.before_proj = zero_module( self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
)
# Output zero linear for the every block # Output zero linear for the every block
self.after_proj_list = nn.ModuleList( self.after_proj_list = nn.ModuleList(
[ [
zero_module(
nn.Linear( operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device self.hidden_size, self.hidden_size, dtype=dtype, device=device
) )
)
for _ in range(len(self.blocks)) for _ in range(len(self.blocks))
] ]
) )
def forward( def forward(
self, self,
x: torch.Tensor, x,
t: torch.Tensor = None, hint,
condition=None, timesteps,
encoder_hidden_states: Optional[torch.Tensor] = None, context,#encoder_hidden_states=None,
text_embedding_mask=None, text_embedding_mask=None,
encoder_hidden_states_t5=None, encoder_hidden_states_t5=None,
text_embedding_mask_t5=None, text_embedding_mask_t5=None,
image_meta_size=None, image_meta_size=None,
style=None, style=None,
control_weight=1.0, return_dict=False,
transformer_options=None,
**kwarg, **kwarg,
): ):
""" """
@ -270,10 +250,11 @@ class HunYuanControlNet(nn.Module):
return_dict: bool return_dict: bool
Whether to return a dictionary. Whether to return a dictionary.
""" """
condition = hint
if condition.shape[0] == 1: if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0) condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = encoder_hidden_states # 2,77,1024 text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77 text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
@ -304,7 +285,7 @@ class HunYuanControlNet(nn.Module):
) # (cos_cis_img, sin_cis_img) ) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding ========================= # ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=self.dtype) t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x) x = self.x_embedder(x)
# ========================= Concatenate all extra vectors ========================= # ========================= Concatenate all extra vectors =========================
@ -337,12 +318,4 @@ class HunYuanControlNet(nn.Module):
x = block(x, c, text_states, freqs_cis_img) x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output controls.append(self.after_proj_list[layer](x)) # zero linear for output
control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
assert len(control_weights) == len(
controls
), "control_weights and controls should have the same length"
controls = [
control * weight for control, weight in zip(controls, control_weights)
]
return {"output": controls} return {"output": controls}

View File

@ -20,57 +20,6 @@ class CLIPTextEncodeHunyuanDiT:
return ([[cond, output]], ) return ([[cond, output]], )
class ControlNetApplyAdvancedHunYuan:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"vae": ("VAE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None):
if strength == 0:
return (positive, negative)
control_hint = image.movedim(-1,1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net.set_extra_arg('control_weight', control_weight)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1])
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan,
} }

View File

@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3, "ControlNetApplySD3": ControlNetApplySD3,
} }
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
}