Added support for lanczos scaling
This commit is contained in:
parent
6d3dee9d16
commit
2b6b178173
|
@ -1,8 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
import comfy.checkpoint_pickle
|
import comfy.checkpoint_pickle
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
|
@ -346,6 +348,13 @@ def bislerp(samples, width, height):
|
||||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def lanczos(samples, width, height):
|
||||||
|
images = [torchvision.transforms.functional.to_pil_image(image) for image in samples]
|
||||||
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||||
|
images = [torchvision.transforms.functional.to_tensor(image) for image in images]
|
||||||
|
result = torch.stack(images)
|
||||||
|
return result
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
old_width = samples.shape[3]
|
old_width = samples.shape[3]
|
||||||
|
@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
|
|
||||||
if upscale_method == "bislerp":
|
if upscale_method == "bislerp":
|
||||||
return bislerp(s, width, height)
|
return bislerp(s, width, height)
|
||||||
|
elif upscale_method == "lanczos":
|
||||||
|
return lanczos(s, width, height)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
|
|
|
@ -211,7 +211,7 @@ class Sharpen:
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
class ImageScaleToTotalPixels:
|
class ImageScaleToTotalPixels:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
4
nodes.py
4
nodes.py
|
@ -1423,7 +1423,7 @@ class LoadImageMask:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class ImageScale:
|
class ImageScale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1444,7 +1444,7 @@ class ImageScale:
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class ImageScaleBy:
|
class ImageScaleBy:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|
Loading…
Reference in New Issue