From a475ec2300abb4eab845510ad0da596114174274 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 02:35:19 -0400 Subject: [PATCH] Cleanup HunyuanDit controlnets. Use the: ControlNetApply SD3 and HunyuanDiT node. --- comfy/controlnet.py | 145 ++++++++++------------------------ comfy/ldm/hydit/controlnet.py | 53 +++---------- comfy_extras/nodes_hunyuan.py | 51 ------------ comfy_extras/nodes_sd3.py | 5 ++ 4 files changed, 60 insertions(+), 194 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index c11a759e..89c3c17e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,4 +1,24 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + + import torch +from enum import Enum import math import os import logging @@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) +class StrengthType(Enum): + CONSTANT = 1 + LINEAR_UP = 2 + class ControlBase: def __init__(self, device=None): self.cond_hint_original = None @@ -51,6 +75,8 @@ class ControlBase: device = comfy.model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.extra_conds = [] + self.strength_type = StrengthType.CONSTANT def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): self.cond_hint_original = cond_hint @@ -93,6 +119,8 @@ class ControlBase: c.latent_format = self.latent_format c.extra_args = self.extra_args.copy() c.vae = self.vae + c.extra_conds = self.extra_conds.copy() + c.strength_type = self.strength_type def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -113,7 +141,10 @@ class ControlBase: if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once applied_to.add(x) - x *= self.strength + if self.strength_type == StrengthType.CONSTANT: + x *= self.strength + elif self.strength_type == StrengthType.LINEAR_UP: + x *= (self.strength ** float(len(control_output) - i)) if x.dtype != output_dtype: x = x.to(output_dtype) @@ -142,7 +173,7 @@ class ControlBase: class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT): super().__init__(device) self.control_model = control_model self.load_device = load_device @@ -154,6 +185,8 @@ class ControlNet(ControlBase): self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype self.latent_format = latent_format + self.extra_conds += extra_conds + self.strength_type = strength_type def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -192,7 +225,7 @@ class ControlNet(ControlBase): context = cond.get('crossattn_controlnet', cond['c_crossattn']) extra = self.extra_args.copy() - for c in ["y", "guidance"]: #TODO + for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: extra[c] = temp.to(dtype) @@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control -class ControlNetWarperHunyuanDiT(ControlNet): - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) +def load_controlnet_hunyuandit(controlnet_data): + model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data) - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return None - - dtype = self.control_model.dtype - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - - output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - compression_ratio = self.compression_ratio - if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") - if self.vae is not None: - loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) - comfy.model_management.load_models_gpu(loaded_models) - if self.latent_format is not None: - self.cond_hint = self.latent_format.process_in(self.cond_hint) - self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - - def get_tensor(name): - if name in cond: - if isinstance(cond[name], torch.Tensor): - return cond[name].to(dtype) - else: - return cond[name] - else: - return None - - encoder_hidden_states = get_tensor('c_crossattn') - text_embedding_mask = get_tensor('text_embedding_mask') - encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5') - text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5') - image_meta_size = get_tensor('image_meta_size') - style = get_tensor('style') - cos_cis_img = get_tensor('cos_cis_img') - sin_cis_img = get_tensor('sin_cis_img') - - timestep = self.model_sampling_current.timestep(t) - x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - - control = self.control_model( - x=x_noisy.to(dtype), - t=timestep.float(), - condition=self.cond_hint, - encoder_hidden_states=encoder_hidden_states, - text_embedding_mask=text_embedding_mask, - encoder_hidden_states_t5=encoder_hidden_states_t5, - text_embedding_mask_t5=text_embedding_mask_t5, - image_meta_size=image_meta_size, - style=style, - cos_cis_img=cos_cis_img, - sin_cis_img=sin_cis_img, - **self.extra_args - ) - return self.control_merge(control, control_prev, output_dtype) - - def copy(self): - c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) - c.control_model = self.control_model - c.control_model_wrapped = self.control_model_wrapped - self.copy_to(c) - return c - -def load_controlnet_hunyuandit(controlnet_data): - - supported_inference_dtypes = [torch.float16, torch.float32] - - unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) - load_device = comfy.model_management.get_torch_device() - manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) - if manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init - control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) - missing, unexpected = control_model.load_state_dict(controlnet_data) - - if len(missing) > 0: - logging.warning("missing controlnet keys: {}".format(missing)) - - if len(unexpected) > 0: - logging.debug("unexpected controlnet keys: {}".format(unexpected)) + control_model = controlnet_load_state_dict(control_model, controlnet_data) latent_format = comfy.latent_formats.SDXL() - control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img'] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP) return control def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data) - + if "lora_controlnet" in controlnet_data: return ControlLora(controlnet_data) diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py index 0d3f7966..cd71fca3 100644 --- a/comfy/ldm/hydit/controlnet.py +++ b/comfy/ldm/hydit/controlnet.py @@ -16,28 +16,11 @@ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from .poolers import AttentionPool import comfy.latent_formats -from .models import HunYuanDiTBlock +from .models import HunYuanDiTBlock, calc_rope from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module - - -def calc_rope(x, patch_size, head_size): - th = (x.shape[2] + (patch_size // 2)) // patch_size - tw = (x.shape[3] + (patch_size // 2)) // patch_size - base_size = 512 // 8 // patch_size - start, stop = get_fill_resize_and_crop((th, tw), base_size) - sub_args = [start, stop, (th, tw)] - # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] - rope = get_2d_rotary_pos_embed(head_size, *sub_args) - return rope - - class HunYuanControlNet(nn.Module): """ HunYuanDiT: Diffusion model with a Transformer backbone. @@ -213,35 +196,32 @@ class HunYuanControlNet(nn.Module): ) # Input zero linear for the first block - self.before_proj = zero_module( - nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - ) + self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + # Output zero linear for the every block self.after_proj_list = nn.ModuleList( [ - zero_module( - nn.Linear( + + operations.Linear( self.hidden_size, self.hidden_size, dtype=dtype, device=device ) - ) for _ in range(len(self.blocks)) ] ) def forward( self, - x: torch.Tensor, - t: torch.Tensor = None, - condition=None, - encoder_hidden_states: Optional[torch.Tensor] = None, + x, + hint, + timesteps, + context,#encoder_hidden_states=None, text_embedding_mask=None, encoder_hidden_states_t5=None, text_embedding_mask_t5=None, image_meta_size=None, style=None, - control_weight=1.0, - transformer_options=None, + return_dict=False, **kwarg, ): """ @@ -270,10 +250,11 @@ class HunYuanControlNet(nn.Module): return_dict: bool Whether to return a dictionary. """ + condition = hint if condition.shape[0] == 1: condition = torch.repeat_interleave(condition, x.shape[0], dim=0) - text_states = encoder_hidden_states # 2,77,1024 + text_states = context # 2,77,1024 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_mask = text_embedding_mask.bool() # 2,77 text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 @@ -304,7 +285,7 @@ class HunYuanControlNet(nn.Module): ) # (cos_cis_img, sin_cis_img) # ========================= Build time and image embedding ========================= - t = self.t_embedder(t, dtype=self.dtype) + t = self.t_embedder(timesteps, dtype=self.dtype) x = self.x_embedder(x) # ========================= Concatenate all extra vectors ========================= @@ -337,12 +318,4 @@ class HunYuanControlNet(nn.Module): x = block(x, c, text_states, freqs_cis_img) controls.append(self.after_proj_list[layer](x)) # zero linear for output - control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)] - assert len(control_weights) == len( - controls - ), "control_weights and controls should have the same length" - controls = [ - control * weight for control, weight in zip(controls, control_weights) - ] - return {"output": controls} diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 4f2ccfe9..b03eaf6a 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -19,58 +19,7 @@ class CLIPTextEncodeHunyuanDiT: cond = output.pop("cond") return ([[cond, output]], ) - -class ControlNetApplyAdvancedHunYuan: - @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "image": ("IMAGE", ), - "vae": ("VAE", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "apply_controlnet" - - CATEGORY = "conditioning/controlnet" - - def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None): - if strength == 0: - return (positive, negative) - - control_hint = image.movedim(-1,1) - cnets = {} - - out = [] - for conditioning in [positive, negative]: - c = [] - for t in conditioning: - d = t[1].copy() - - prev_cnet = d.get('control', None) - if prev_cnet in cnets: - c_net = cnets[prev_cnet] - else: - c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) - c_net.set_extra_arg('control_weight', control_weight) - - c_net.set_previous_controlnet(prev_cnet) - cnets[prev_cnet] = c_net - - d['control'] = c_net - d['control_apply_to_uncond'] = False - n = [t[0], d] - c.append(n) - out.append(c) - return (out[0], out[1]) - NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, - "ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan, } diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index ae9b8598..046096cb 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSD3": CLIPTextEncodeSD3, "ControlNetApplySD3": ControlNetApplySD3, } + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", +}