Make VAEDecodeTiled node work with video VAEs.

This commit is contained in:
comfyanonymous 2024-11-07 03:47:12 -05:00
parent 5e29e7a488
commit b49616f951
2 changed files with 27 additions and 6 deletions

View File

@ -361,10 +361,25 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) dims = samples.ndim - 2
return output.movedim(1,-1) args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)

View File

@ -290,15 +290,21 @@ class VAEDecodeTiled:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
}} }}
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode" FUNCTION = "decode"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size): def decode(self, vae, samples, tile_size, overlap):
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), ) if tile_size < overlap * 4:
overlap = tile_size // 4
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
class VAEEncode: class VAEEncode:
@classmethod @classmethod