blend supports any size, dither -> quantize
This commit is contained in:
parent
4c7a9dbcb6
commit
fa2febc062
|
@ -1,5 +1,7 @@
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class Blend:
|
class Blend:
|
||||||
|
@ -28,6 +30,9 @@ class Blend:
|
||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
||||||
|
if image1.shape != image2.shape:
|
||||||
|
image2 = self.crop_and_resize(image2, image1.shape)
|
||||||
|
|
||||||
blended_image = self.blend_mode(image1, image2, blend_mode)
|
blended_image = self.blend_mode(image1, image2, blend_mode)
|
||||||
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
||||||
blended_image = torch.clamp(blended_image, 0, 1)
|
blended_image = torch.clamp(blended_image, 0, 1)
|
||||||
|
@ -50,6 +55,29 @@ class Blend:
|
||||||
def g(self, x):
|
def g(self, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
|
def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
|
||||||
|
batch_size, img_h, img_w, img_c = img.shape
|
||||||
|
_, target_h, target_w, _ = target_shape
|
||||||
|
img_aspect_ratio = img_w / img_h
|
||||||
|
target_aspect_ratio = target_w / target_h
|
||||||
|
|
||||||
|
# Crop center of the image to the target aspect ratio
|
||||||
|
if img_aspect_ratio > target_aspect_ratio:
|
||||||
|
new_width = int(img_h * target_aspect_ratio)
|
||||||
|
left = (img_w - new_width) // 2
|
||||||
|
img = img[:, :, left:left + new_width, :]
|
||||||
|
else:
|
||||||
|
new_height = int(img_w / target_aspect_ratio)
|
||||||
|
top = (img_h - new_height) // 2
|
||||||
|
img = img[:, top:top + new_height, :, :]
|
||||||
|
|
||||||
|
# Resize to target size
|
||||||
|
img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
|
img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
|
||||||
|
img = img.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
class Blur:
|
class Blur:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -100,7 +128,7 @@ class Blur:
|
||||||
|
|
||||||
return (blurred,)
|
return (blurred,)
|
||||||
|
|
||||||
class Dither:
|
class Quantize:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -109,51 +137,37 @@ class Dither:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
"bits": ("INT", {
|
"colors": ("INT", {
|
||||||
"default": 4,
|
"default": 256,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
"max": 8,
|
"max": 256,
|
||||||
"step": 1
|
"step": 1
|
||||||
}),
|
}),
|
||||||
|
"dither": (["none", "floyd-steinberg"],),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "dither"
|
FUNCTION = "quantize"
|
||||||
|
|
||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def dither(self, image: torch.Tensor, bits: int):
|
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
|
||||||
batch_size, height, width, _ = image.shape
|
batch_size, height, width, _ = image.shape
|
||||||
result = torch.zeros_like(image)
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
tensor_image = image[b]
|
tensor_image = image[b]
|
||||||
img = (tensor_image * 255)
|
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||||
height, width, _ = img.shape
|
pil_image = Image.fromarray(img, mode='RGB')
|
||||||
|
|
||||||
scale = 255 / (2**bits - 1)
|
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||||
|
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
|
||||||
|
|
||||||
for y in range(height):
|
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
||||||
for x in range(width):
|
result[b] = quantized_array
|
||||||
old_pixel = img[y, x].clone()
|
|
||||||
new_pixel = torch.round(old_pixel / scale) * scale
|
|
||||||
img[y, x] = new_pixel
|
|
||||||
|
|
||||||
quant_error = old_pixel - new_pixel
|
|
||||||
|
|
||||||
if x + 1 < width:
|
|
||||||
img[y, x + 1] += quant_error * 7 / 16
|
|
||||||
if y + 1 < height:
|
|
||||||
if x - 1 >= 0:
|
|
||||||
img[y + 1, x - 1] += quant_error * 3 / 16
|
|
||||||
img[y + 1, x] += quant_error * 5 / 16
|
|
||||||
if x + 1 < width:
|
|
||||||
img[y + 1, x + 1] += quant_error * 1 / 16
|
|
||||||
|
|
||||||
dithered = img / 255
|
|
||||||
tensor = dithered.unsqueeze(0)
|
|
||||||
result[b] = tensor
|
|
||||||
|
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
|
@ -210,6 +224,6 @@ class Sharpen:
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Blend": Blend,
|
"Blend": Blend,
|
||||||
"Blur": Blur,
|
"Blur": Blur,
|
||||||
"Dither": Dither,
|
"Quantize": Quantize,
|
||||||
"Sharpen": Sharpen,
|
"Sharpen": Sharpen,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue