Make VAEDecodeTiled node work with video VAEs.
This commit is contained in:
parent
5e29e7a488
commit
b49616f951
21
comfy/sd.py
21
comfy/sd.py
|
@@ -361,10 +361,25 @@ class VAE:
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
    """Tiled VAE decode that dispatches on latent dimensionality.

    Args:
        samples: latent tensor; ``samples.ndim - 2`` gives the number of
            spatial dims (1 = sequence, 2 = image, 3 = video latents).
        tile_x, tile_y, overlap: optional tile sizes / tile overlap in
            latent units. ``None`` means "let the underlying tiled
            decoder use its own default".

    Returns:
        Decoded pixel tensor with the channel axis moved to last.

    Raises:
        ValueError: if the latent dimensionality is not 1, 2 or 3.
    """
    model_management.load_model_gpu(self.patcher)
    dims = samples.ndim - 2

    # Only forward tiling options that were explicitly supplied, so each
    # decoder keeps its own defaults for anything left as None.
    args = {}
    if tile_x is not None:
        args["tile_x"] = tile_x
    if tile_y is not None:
        args["tile_y"] = tile_y
    if overlap is not None:
        args["overlap"] = overlap

    if dims == 1:
        # 1D latents have no y axis. Use pop with a default: a plain
        # pop("tile_y") raised KeyError whenever tile_y was None, since
        # the key was never inserted above.
        args.pop("tile_y", None)
        output = self.decode_tiled_1d(samples, **args)
    elif dims == 2:
        output = self.decode_tiled_(samples, **args)
    elif dims == 3:
        output = self.decode_tiled_3d(samples, **args)
    else:
        # Previously an unexpected dims fell through and crashed later
        # with an UnboundLocalError on `output`; fail loudly instead.
        raise ValueError("unsupported latent dimensions: {}".format(dims))

    return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
|
|
12
nodes.py
12
nodes.py
|
@@ -290,15 +290,21 @@ class VAEDecodeTiled:
|
||||||
@classmethod
def INPUT_TYPES(s):
    """Node inputs: a latent, a VAE, and tiling controls in pixels."""
    required = {"samples": ("LATENT", ), "vae": ("VAE", )}
    # Tile geometry (pixel units; converted to latent units in decode()).
    required["tile_size"] = ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32})
    required["overlap"] = ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32})
    return {"required": required}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def decode(self, vae, samples, tile_size, overlap):
    """Decode a latent in tiles and return the images as a 1-tuple.

    ``overlap`` is clamped to a quarter of ``tile_size`` so tiles never
    overlap by more than 25%. Both are given in pixels and divided by 8
    to get latent units. A 5D result (video VAEs) is flattened so every
    frame becomes its own batch entry.
    """
    if overlap * 4 > tile_size:
        overlap = tile_size // 4
    tile = tile_size // 8  # pixel -> latent units
    images = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap // 8)
    if images.ndim == 5:  # combine batch and frame axes
        h, w, c = images.shape[-3], images.shape[-2], images.shape[-1]
        images = images.reshape(-1, h, w, c)
    return (images, )
|
||||||
|
|
||||||
class VAEEncode:
|
class VAEEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue