Added support for lanczos scaling

This commit is contained in:
MoonRide303 2023-09-19 10:40:38 +02:00
parent 6d3dee9d16
commit 2b6b178173
3 changed files with 14 additions and 3 deletions

View File

@ -1,8 +1,10 @@
import torch
import torchvision
import math
import struct
import comfy.checkpoint_pickle
import safetensors.torch
from PIL import Image
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@ -346,6 +348,13 @@ def bislerp(samples, width, height):
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
def lanczos(samples, width, height):
images = [torchvision.transforms.functional.to_pil_image(image) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torchvision.transforms.functional.to_tensor(image) for image in images]
result = torch.stack(images)
return result
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop):
if upscale_method == "bislerp":
return bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

View File

@ -211,7 +211,7 @@ class Sharpen:
return (result,)
class ImageScaleToTotalPixels:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@classmethod

View File

@ -1423,7 +1423,7 @@ class LoadImageMask:
return True
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@classmethod
@ -1444,7 +1444,7 @@ class ImageScale:
return (s,)
class ImageScaleBy:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod
def INPUT_TYPES(s):