From c3ffbae0677f4b8caccd2cf363c54d47d9dae3a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 26 Oct 2024 01:50:51 -0400 Subject: [PATCH] Make LatentUpscale nodes work on 3d latents. --- comfy/utils.py | 23 +++++++++++++++++------ nodes.py | 8 ++++---- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 7cef9044..056cf363 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -690,9 +690,14 @@ def lanczos(samples, width, height): return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): + orig_shape = tuple(samples.shape) + if len(orig_shape) > 4: + samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1]) + samples = samples.movedim(2, 1) + samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1]) if crop == "center": - old_width = samples.shape[3] - old_height = samples.shape[2] + old_width = samples.shape[-1] + old_height = samples.shape[-2] old_aspect = old_width / old_height new_aspect = width / height x = 0 @@ -701,16 +706,22 @@ def common_upscale(samples, width, height, upscale_method, crop): x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) elif old_aspect < new_aspect: y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples[:,:,y:old_height-y,x:old_width-x] + s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2) else: s = samples if upscale_method == "bislerp": - return bislerp(s, width, height) + out = bislerp(s, width, height) elif upscale_method == "lanczos": - return lanczos(s, width, height) + out = lanczos(s, width, height) else: - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if len(orig_shape) == 4: + return out + + out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width)) + return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width)) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap)) diff --git a/nodes.py b/nodes.py index 15a78352..ff45acf8 100644 --- a/nodes.py +++ b/nodes.py @@ -1179,10 +1179,10 @@ class LatentUpscale: if width == 0: height = max(64, height) - width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2])) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) elif height == 0: width = max(64, width) - height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3])) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) else: width = max(64, width) height = max(64, height) @@ -1204,8 +1204,8 @@ class LatentUpscaleBy: def upscale(self, samples, upscale_method, scale_by): s = samples.copy() - width = round(samples["samples"].shape[3] * scale_by) - height = round(samples["samples"].shape[2] * scale_by) + width = round(samples["samples"].shape[-1] * scale_by) + height = round(samples["samples"].shape[-2] * scale_by) s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") return (s,)