blend supports any size, dither -> quantize
This commit is contained in:
parent
4c7a9dbcb6
commit
fa2febc062
|
@ -1,5 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Blend:
|
||||
|
@ -28,6 +30,9 @@ class Blend:
|
|||
CATEGORY = "postprocessing"
|
||||
|
||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
||||
if image1.shape != image2.shape:
|
||||
image2 = self.crop_and_resize(image2, image1.shape)
|
||||
|
||||
blended_image = self.blend_mode(image1, image2, blend_mode)
|
||||
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
||||
blended_image = torch.clamp(blended_image, 0, 1)
|
||||
|
@ -50,6 +55,29 @@ class Blend:
|
|||
def g(self, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
|
||||
batch_size, img_h, img_w, img_c = img.shape
|
||||
_, target_h, target_w, _ = target_shape
|
||||
img_aspect_ratio = img_w / img_h
|
||||
target_aspect_ratio = target_w / target_h
|
||||
|
||||
# Crop center of the image to the target aspect ratio
|
||||
if img_aspect_ratio > target_aspect_ratio:
|
||||
new_width = int(img_h * target_aspect_ratio)
|
||||
left = (img_w - new_width) // 2
|
||||
img = img[:, :, left:left + new_width, :]
|
||||
else:
|
||||
new_height = int(img_w / target_aspect_ratio)
|
||||
top = (img_h - new_height) // 2
|
||||
img = img[:, top:top + new_height, :, :]
|
||||
|
||||
# Resize to target size
|
||||
img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
|
||||
return img
|
||||
|
||||
class Blur:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -100,7 +128,7 @@ class Blur:
|
|||
|
||||
return (blurred,)
|
||||
|
||||
class Dither:
|
||||
class Quantize:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
@ -109,51 +137,37 @@ class Dither:
|
|||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"bits": ("INT", {
|
||||
"default": 4,
|
||||
"colors": ("INT", {
|
||||
"default": 256,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"max": 256,
|
||||
"step": 1
|
||||
}),
|
||||
"dither": (["none", "floyd-steinberg"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "dither"
|
||||
FUNCTION = "quantize"
|
||||
|
||||
CATEGORY = "postprocessing"
|
||||
|
||||
def dither(self, image: torch.Tensor, bits: int):
|
||||
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255)
|
||||
height, width, _ = img.shape
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
|
||||
scale = 255 / (2**bits - 1)
|
||||
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
|
||||
|
||||
for y in range(height):
|
||||
for x in range(width):
|
||||
old_pixel = img[y, x].clone()
|
||||
new_pixel = torch.round(old_pixel / scale) * scale
|
||||
img[y, x] = new_pixel
|
||||
|
||||
quant_error = old_pixel - new_pixel
|
||||
|
||||
if x + 1 < width:
|
||||
img[y, x + 1] += quant_error * 7 / 16
|
||||
if y + 1 < height:
|
||||
if x - 1 >= 0:
|
||||
img[y + 1, x - 1] += quant_error * 3 / 16
|
||||
img[y + 1, x] += quant_error * 5 / 16
|
||||
if x + 1 < width:
|
||||
img[y + 1, x + 1] += quant_error * 1 / 16
|
||||
|
||||
dithered = img / 255
|
||||
tensor = dithered.unsqueeze(0)
|
||||
result[b] = tensor
|
||||
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
||||
result[b] = quantized_array
|
||||
|
||||
return (result,)
|
||||
|
||||
|
@ -210,6 +224,6 @@ class Sharpen:
|
|||
NODE_CLASS_MAPPINGS = {
|
||||
"Blend": Blend,
|
||||
"Blur": Blur,
|
||||
"Dither": Dither,
|
||||
"Quantize": Quantize,
|
||||
"Sharpen": Sharpen,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue