Make LatentUpscale nodes work on 3d latents.

This commit is contained in:
comfyanonymous 2024-10-26 01:50:51 -04:00
parent d605677b33
commit c3ffbae067
2 changed files with 21 additions and 10 deletions

View File

@ -690,9 +690,14 @@ def lanczos(samples, width, height):
return result.to(samples.device, samples.dtype) return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop): def common_upscale(samples, width, height, upscale_method, crop):
orig_shape = tuple(samples.shape)
if len(orig_shape) > 4:
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
samples = samples.movedim(2, 1)
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
if crop == "center": if crop == "center":
old_width = samples.shape[3] old_width = samples.shape[-1]
old_height = samples.shape[2] old_height = samples.shape[-2]
old_aspect = old_width / old_height old_aspect = old_width / old_height
new_aspect = width / height new_aspect = width / height
x = 0 x = 0
@ -701,16 +706,22 @@ def common_upscale(samples, width, height, upscale_method, crop):
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect: elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x] s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
else: else:
s = samples s = samples
if upscale_method == "bislerp": if upscale_method == "bislerp":
return bislerp(s, width, height) out = bislerp(s, width, height)
elif upscale_method == "lanczos": elif upscale_method == "lanczos":
return lanczos(s, width, height) out = lanczos(s, width, height)
else: else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if len(orig_shape) == 4:
return out
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap)) rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))

View File

@ -1179,10 +1179,10 @@ class LatentUpscale:
if width == 0: if width == 0:
height = max(64, height) height = max(64, height)
width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2])) width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
elif height == 0: elif height == 0:
width = max(64, width) width = max(64, width)
height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3])) height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
else: else:
width = max(64, width) width = max(64, width)
height = max(64, height) height = max(64, height)
@ -1204,8 +1204,8 @@ class LatentUpscaleBy:
def upscale(self, samples, upscale_method, scale_by): def upscale(self, samples, upscale_method, scale_by):
s = samples.copy() s = samples.copy()
width = round(samples["samples"].shape[3] * scale_by) width = round(samples["samples"].shape[-1] * scale_by)
height = round(samples["samples"].shape[2] * scale_by) height = round(samples["samples"].shape[-2] * scale_by)
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
return (s,) return (s,)