From 2b6b17817331a24afc7106bfe9ec3e2f9b03fab1 Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Tue, 19 Sep 2023 10:40:38 +0200 Subject: [PATCH] Added support for lanczos scaling --- comfy/utils.py | 11 +++++++++++ comfy_extras/nodes_post_processing.py | 2 +- nodes.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 3ed32e37..4e08bcb8 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,8 +1,10 @@ import torch +import torchvision import math import struct import comfy.checkpoint_pickle import safetensors.torch +from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -346,6 +348,13 @@ def bislerp(samples, width, height): result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result +def lanczos(samples, width, height): + images = [torchvision.transforms.functional.to_pil_image(image) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torchvision.transforms.functional.to_tensor(image) for image in images] + result = torch.stack(images) + return result + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop): if upscale_method == "bislerp": return bislerp(s, width, height) + elif upscale_method == "lanczos": + return lanczos(s, width, height) else: return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 51bdb24f..3f651e59 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -211,7 +211,7 @@ class Sharpen: return (result,) class ImageScaleToTotalPixels: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod diff --git a/nodes.py b/nodes.py index 9ccf179c..59c50a16 100644 --- a/nodes.py +++ b/nodes.py @@ -1423,7 +1423,7 @@ class LoadImageMask: return True class ImageScale: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod @@ -1444,7 +1444,7 @@ class ImageScale: return (s,) class ImageScaleBy: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s):