From b49616f9511c57c8d54c4032e305d72352ac4ff5 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 7 Nov 2024 03:47:12 -0500
Subject: [PATCH] Make VAEDecodeTiled node work with video VAEs.

---
 comfy/sd.py | 21 ++++++++++++++++++---
 nodes.py    | 12 +++++++++---
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index d5c02e9a..7e76f6fa 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -361,10 +361,25 @@ class VAE:
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
 
-    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
+    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
         model_management.load_model_gpu(self.patcher)
-        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
-        return output.movedim(1,-1)
+        dims = samples.ndim - 2
+        args = {}
+        if tile_x is not None:
+            args["tile_x"] = tile_x
+        if tile_y is not None:
+            args["tile_y"] = tile_y
+        if overlap is not None:
+            args["overlap"] = overlap
+
+        if dims == 1:
+            args.pop("tile_y")
+            output = self.decode_tiled_1d(samples, **args)
+        elif dims == 2:
+            output = self.decode_tiled_(samples, **args)
+        elif dims == 3:
+            output = self.decode_tiled_3d(samples, **args)
+        return output.movedim(1, -1)
 
     def encode(self, pixel_samples):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
diff --git a/nodes.py b/nodes.py
index 6397654b..d48bbc28 100644
--- a/nodes.py
+++ b/nodes.py
@@ -290,15 +290,21 @@ class VAEDecodeTiled:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
-                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
+                             "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
+                             "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                             }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"
 
     CATEGORY = "_for_testing"
 
-    def decode(self, vae, samples, tile_size):
-        return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
+    def decode(self, vae, samples, tile_size, overlap):
+        if tile_size < overlap * 4:
+            overlap = tile_size // 4
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        if len(images.shape) == 5: #Combine batches
+            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
+        return (images, )
 
 class VAEEncode:
     @classmethod
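
The reshape added to VAEDecodeTiled.decode handles the case where a video VAE's tiled decode comes back 5-dimensional after the channel axis is moved last; the node flattens the batch and frame axes into one image batch so the IMAGE output stays 4-dimensional for downstream nodes. A minimal sketch of that flattening on a dummy tensor follows; the (batch, frames, height, width, channels) layout is an assumption for illustration, not something stated in the patch:

    import torch

    # Stand-in for the 5D output of a video VAE tiled decode after movedim(1, -1).
    # Assumed layout: (batch, frames, height, width, channels).
    images = torch.rand(2, 9, 64, 64, 3)

    if len(images.shape) == 5:  # combine batch and frame dimensions, as the node does
        images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])

    print(images.shape)  # torch.Size([18, 64, 64, 3])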