diff --git a/comfy/cldm/mmdit.py b/comfy/cldm/mmdit.py index 6e72474c..ee0282bc 100644 --- a/comfy/cldm/mmdit.py +++ b/comfy/cldm/mmdit.py @@ -1,7 +1,6 @@ import torch from typing import Dict, Optional import comfy.ldm.modules.diffusionmodules.mmdit -import comfy.latent_formats class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): def __init__( @@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): operations=operations ) - self.latent_format = comfy.latent_formats.SD3() - def forward( self, x: torch.Tensor, @@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): ) -> torch.Tensor: #weird sd3 controlnet specific stuff - hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint) y = torch.zeros_like(y) - if self.context_processor is not None: context = self.context_processor(context) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9202c319..d0039513 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.model_detection import comfy.model_patcher import comfy.ops +import comfy.latent_formats import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -38,6 +39,8 @@ class ControlBase: self.cond_hint = None self.strength = 1.0 self.timestep_percent_range = (0.0, 1.0) + self.latent_format = None + self.vae = None self.global_average_pooling = False self.timestep_range = None self.compression_ratio = 8 @@ -48,10 +51,12 @@ class ControlBase: self.device = device self.previous_controlnet = None - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): self.cond_hint_original = cond_hint self.strength = strength self.timestep_percent_range = timestep_percent_range + if self.latent_format is not None: + self.vae = vae return self def pre_run(self, model, percent_to_timestep_function): @@ -84,6 +89,8 @@ class ControlBase: c.global_average_pooling = self.global_average_pooling c.compression_ratio = self.compression_ratio c.upscale_algorithm = self.upscale_algorithm + c.latent_format = self.latent_format + c.vae = self.vae def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -129,7 +136,7 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device @@ -140,6 +147,7 @@ class ControlNet(ControlBase): self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype + self.latent_format = latent_format def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -162,7 +170,17 @@ class ControlNet(ControlBase): if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device) + compression_ratio = self.compression_ratio + if self.vae is not None: + compression_ratio *= self.vae.downscale_ratio + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + if self.vae is not None: + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) + comfy.model_management.load_models_gpu(loaded_models) + if self.latent_format is not None: + self.cond_hint = self.latent_format.process_in(self.cond_hint) + self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) @@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd): if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) - control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + latent_format = comfy.latent_formats.SD3() + latent_format.shift_factor = 0 #SD3 controlnet weirdness + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index d0303aec..548b1ad6 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -80,8 +80,23 @@ class CLIPTextEncodeSD3: return ([[cond, {"pooled_output": pooled}]], ) +class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "vae": ("VAE", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + CATEGORY = "_for_testing/sd3" + NODE_CLASS_MAPPINGS = { "TripleCLIPLoader": TripleCLIPLoader, "EmptySD3LatentImage": EmptySD3LatentImage, "CLIPTextEncodeSD3": CLIPTextEncodeSD3, + "ControlNetApplySD3": ControlNetApplySD3, } diff --git a/nodes.py b/nodes.py index 99645b81..04775d29 100644 --- a/nodes.py +++ b/nodes.py @@ -783,7 +783,7 @@ class ControlNetApplyAdvanced: CATEGORY = "conditioning" - def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent): + def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None): if strength == 0: return (positive, negative) @@ -800,7 +800,7 @@ class ControlNetApplyAdvanced: if prev_cnet in cnets: c_net = cnets[prev_cnet] else: - c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent)) + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) c_net.set_previous_controlnet(prev_cnet) cnets[prev_cnet] = c_net