Improve tiling calculations to reduce number of tiles that need to be processed. (#4944)
This commit is contained in:
parent
d514bb38ee
commit
0b7dfa986d
|
@ -713,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||||
|
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||||
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
|
@ -722,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
|
|
||||||
|
# handle entire input fitting in a single tile
|
||||||
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||||
|
output[b:b+1] = function(s).to(output_device)
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
|
|
||||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
|
for it in itertools.product(*positions):
|
||||||
s_in = s
|
s_in = s
|
||||||
upscaled = []
|
upscaled = []
|
||||||
|
|
||||||
|
@ -734,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(pos * upscale_amount))
|
upscaled.append(round(pos * upscale_amount))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
|
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
m = mask.narrow(d, t, 1)
|
a = (t + 1) / feather
|
||||||
m *= ((1.0/feather) * (t + 1))
|
mask.narrow(d, t, 1).mul_(a)
|
||||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||||
m *= ((1.0/feather) * (t + 1))
|
|
||||||
|
|
||||||
o = out
|
o = out
|
||||||
o_d = out_div
|
o_d = out_div
|
||||||
|
@ -750,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
o += ps * mask
|
o.add_(ps * mask)
|
||||||
o_d += mask
|
o_d.add_(mask)
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
Loading…
Reference in New Issue