From 6cd8ffc465ed363b078249b081ea3f975e77cf15 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 8 Jun 2024 02:16:55 -0400 Subject: [PATCH] Reshape the empty latent image to the right amount of channels if needed. --- comfy/latent_formats.py | 2 ++ comfy/sample.py | 6 ++++++ comfy/utils.py | 10 +++++----- comfy_extras/nodes_custom_sampler.py | 2 ++ nodes.py | 2 ++ 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4ca466d9..69192bc6 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -2,6 +2,7 @@ import torch class LatentFormat: scale_factor = 1.0 + latent_channels = 4 latent_rgb_factors = None taesd_decoder_name = None @@ -72,6 +73,7 @@ class SD_X4(LatentFormat): ] class SC_Prior(LatentFormat): + latent_channels = 16 def __init__(self): self.scale_factor = 1.0 self.latent_rgb_factors = [ diff --git a/comfy/sample.py b/comfy/sample.py index e51bd67d..98dcaca7 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -24,6 +24,12 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises = torch.cat(noises, axis=0) return noises +def fix_empty_latent_channels(model, latent_image): + latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels + if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1) + return latent_image + def prepare_sampling(model, noise_shape, positive, negative, noise_mask): logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") return model, positive, negative, noise_mask, [] diff --git a/comfy/utils.py b/comfy/utils.py index ab47b8f2..884404cc 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -249,11 +249,11 @@ def unet_to_diffusers(unet_config): return diffusers_unet_map -def repeat_to_batch_size(tensor, batch_size): - if tensor.shape[0] > batch_size: - return tensor[:batch_size] - elif tensor.shape[0] < batch_size: - return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] +def repeat_to_batch_size(tensor, batch_size, dim=0): + if tensor.shape[dim] > batch_size: + return tensor.narrow(dim, 0, batch_size) + elif tensor.shape[dim] < batch_size: + return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size) return tensor def resize_to_batch_size(tensor, batch_size): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 47f08bf6..45ef8cf4 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -380,6 +380,7 @@ class SamplerCustom: def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) if not add_noise: noise = Noise_EmptyNoise().generate_noise(latent) else: @@ -538,6 +539,7 @@ class SamplerCustomAdvanced: def sample(self, noise, guider, sampler, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) noise_mask = None if "noise_mask" in latent: diff --git a/nodes.py b/nodes.py index f454ff8c..b744b53f 100644 --- a/nodes.py +++ b/nodes.py @@ -1299,6 +1299,8 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: