19 lines
755 B
Python
19 lines
755 B
Python
|
import torch
|
||
|
|
||
|
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize ``samples`` to ``(height, width)``, optionally center-cropping first.

    The tensor is indexed with height on dimension 2 and width on
    dimension 3, and is passed to ``torch.nn.functional.interpolate`` with
    a two-element ``size``, so a 4-D input is expected (presumably
    ``(batch, channels, height, width)`` -- confirm against callers).

    Args:
        samples: Tensor to resize.
        width: Target width in pixels.
        height: Target height in pixels.
        upscale_method: Forwarded unchanged as ``mode`` to
            ``torch.nn.functional.interpolate``.
        crop: When equal to ``"center"``, the input is symmetrically cropped
            to the target aspect ratio before resizing. Any other value
            resizes the full input, which may change its aspect ratio.

    Returns:
        The resized tensor with spatial size ``(height, width)``.
    """
    if crop != "center":
        region = samples
    else:
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height

        # Margin trimmed from EACH side along the axis that is too long.
        x = 0
        y = 0
        if old_aspect > new_aspect:
            # Input is too wide: trim columns on the left and right.
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
            # For extreme aspect ratios the rounded margin can swallow the
            # whole axis and leave an empty slice, which interpolate()
            # rejects. Always keep at least one column.
            x = min(x, (old_width - 1) // 2)
        elif old_aspect < new_aspect:
            # Input is too tall: trim rows on the top and bottom.
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
            # Same guard as above: always keep at least one row.
            y = min(y, (old_height - 1) // 2)

        region = samples[:, :, y:old_height - y, x:old_width - x]

    return torch.nn.functional.interpolate(
        region, size=(height, width), mode=upscale_method
    )
|