diff --git a/comfy/utils.py b/comfy/utils.py index 1bc35df7..78f8314f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): - return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) + rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap)) + cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap)) + return rows * cols @torch.inference_mode() def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): @@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_ for b in range(samples.shape[0]): s = samples[b:b+1] + + # handle entire input fitting in a single tile + if all(s.shape[d+2] <= tile[d] for d in range(dims)): + output[b:b+1] = function(s).to(output_device) + if pbar is not None: + pbar.update(1) + continue + out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) - for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): + positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] + + for it in itertools.product(*positions): s_in = s upscaled = [] @@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_ l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) upscaled.append(round(pos * upscale_amount)) + ps = function(s_in).to(output_device) mask = torch.ones_like(ps) feather = round(overlap * upscale_amount) + for t in range(feather): for d in range(2, dims + 2): - m = mask.narrow(d, t, 1) - m *= ((1.0/feather) * (t + 1)) - m = mask.narrow(d, mask.shape[d] -1 -t, 1) - m *= ((1.0/feather) * (t + 1)) + a = (t + 1) / feather + mask.narrow(d, t, 1).mul_(a) + mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) o = out o_d = out_div @@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_ o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - o += ps * mask - o_d += mask + o.add_(ps * mask) + o_d.add_(mask) if pbar is not None: pbar.update(1)