Reshape the empty latent image to the right amount of channels if needed.

This commit is contained in:
comfyanonymous 2024-06-08 02:16:55 -04:00
parent 56333d4850
commit 6cd8ffc465
5 changed files with 17 additions and 5 deletions

View File

@ -2,6 +2,7 @@ import torch
class LatentFormat: class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None latent_rgb_factors = None
taesd_decoder_name = None taesd_decoder_name = None
@ -72,6 +73,7 @@ class SD_X4(LatentFormat):
] ]
class SC_Prior(LatentFormat): class SC_Prior(LatentFormat):
latent_channels = 16
def __init__(self): def __init__(self):
self.scale_factor = 1.0 self.scale_factor = 1.0
self.latent_rgb_factors = [ self.latent_rgb_factors = [

View File

@ -24,6 +24,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises = torch.cat(noises, axis=0) noises = torch.cat(noises, axis=0)
return noises return noises
def fix_empty_latent_channels(model, latent_image):
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
return latent_image
def prepare_sampling(model, noise_shape, positive, negative, noise_mask): def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
return model, positive, negative, noise_mask, [] return model, positive, negative, noise_mask, []

View File

@ -249,11 +249,11 @@ def unet_to_diffusers(unet_config):
return diffusers_unet_map return diffusers_unet_map
def repeat_to_batch_size(tensor, batch_size): def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[0] > batch_size: if tensor.shape[dim] > batch_size:
return tensor[:batch_size] return tensor.narrow(dim, 0, batch_size)
elif tensor.shape[0] < batch_size: elif tensor.shape[dim] < batch_size:
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
return tensor return tensor
def resize_to_batch_size(tensor, batch_size): def resize_to_batch_size(tensor, batch_size):

View File

@ -380,6 +380,7 @@ class SamplerCustom:
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
latent = latent_image latent = latent_image
latent_image = latent["samples"] latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
if not add_noise: if not add_noise:
noise = Noise_EmptyNoise().generate_noise(latent) noise = Noise_EmptyNoise().generate_noise(latent)
else: else:
@ -538,6 +539,7 @@ class SamplerCustomAdvanced:
def sample(self, noise, guider, sampler, sigmas, latent_image): def sample(self, noise, guider, sampler, sigmas, latent_image):
latent = latent_image latent = latent_image
latent_image = latent["samples"] latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
noise_mask = None noise_mask = None
if "noise_mask" in latent: if "noise_mask" in latent:

View File

@ -1299,6 +1299,8 @@ class SetLatentNoiseMask:
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"] latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
if disable_noise: if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else: else: