ControlNetApplySD3 node can now be used to use SD3 controlnets.
This commit is contained in:
parent
f8f7568d03
commit
264caca20e
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||||
import comfy.latent_formats
|
|
||||||
|
|
||||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
self.latent_format = comfy.latent_formats.SD3()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
#weird sd3 controlnet specific stuff
|
#weird sd3 controlnet specific stuff
|
||||||
hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)
|
|
||||||
y = torch.zeros_like(y)
|
y = torch.zeros_like(y)
|
||||||
|
|
||||||
|
|
||||||
if self.context_processor is not None:
|
if self.context_processor is not None:
|
||||||
context = self.context_processor(context)
|
context = self.context_processor(context)
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.latent_formats
|
||||||
|
|
||||||
import comfy.cldm.cldm
|
import comfy.cldm.cldm
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
|
@ -38,6 +39,8 @@ class ControlBase:
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
self.timestep_percent_range = (0.0, 1.0)
|
self.timestep_percent_range = (0.0, 1.0)
|
||||||
|
self.latent_format = None
|
||||||
|
self.vae = None
|
||||||
self.global_average_pooling = False
|
self.global_average_pooling = False
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
|
@ -48,10 +51,12 @@ class ControlBase:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
|
if self.latent_format is not None:
|
||||||
|
self.vae = vae
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
@ -84,6 +89,8 @@ class ControlBase:
|
||||||
c.global_average_pooling = self.global_average_pooling
|
c.global_average_pooling = self.global_average_pooling
|
||||||
c.compression_ratio = self.compression_ratio
|
c.compression_ratio = self.compression_ratio
|
||||||
c.upscale_algorithm = self.upscale_algorithm
|
c.upscale_algorithm = self.upscale_algorithm
|
||||||
|
c.latent_format = self.latent_format
|
||||||
|
c.vae = self.vae
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
|
@ -129,7 +136,7 @@ class ControlBase:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
|
@ -140,6 +147,7 @@ class ControlNet(ControlBase):
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
self.latent_format = latent_format
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
|
@ -162,7 +170,17 @@ class ControlNet(ControlBase):
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device)
|
compression_ratio = self.compression_ratio
|
||||||
|
if self.vae is not None:
|
||||||
|
compression_ratio *= self.vae.downscale_ratio
|
||||||
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
|
if self.vae is not None:
|
||||||
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
|
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||||
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
if self.latent_format is not None:
|
||||||
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
|
@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
|
||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
latent_format = comfy.latent_formats.SD3()
|
||||||
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
|
||||||
return ([[cond, {"pooled_output": pooled}]], )
|
return ([[cond, {"pooled_output": pooled}]], )
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
CATEGORY = "_for_testing/sd3"
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
"EmptySD3LatentImage": EmptySD3LatentImage,
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||||
|
"ControlNetApplySD3": ControlNetApplySD3,
|
||||||
}
|
}
|
||||||
|
|
4
nodes.py
4
nodes.py
|
@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
|
@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
|
||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue