From fa2febc0624678362cc758d316bb59afce9c8f06 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 3 Apr 2023 09:52:04 -0400 Subject: [PATCH] blend supports any size, dither -> quantize --- comfy_extras/nodes_post_processing.py | 74 ++++++++++++++++----------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3f3bddd7..322f3ca8 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,5 +1,7 @@ +import numpy as np import torch import torch.nn.functional as F +from PIL import Image class Blend: @@ -28,6 +30,9 @@ class Blend: CATEGORY = "postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + if image1.shape != image2.shape: + image2 = self.crop_and_resize(image2, image1.shape) + blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = torch.clamp(blended_image, 0, 1) @@ -50,6 +55,29 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + def crop_and_resize(self, img: torch.Tensor, target_shape: tuple): + batch_size, img_h, img_w, img_c = img.shape + _, target_h, target_w, _ = target_shape + img_aspect_ratio = img_w / img_h + target_aspect_ratio = target_w / target_h + + # Crop center of the image to the target aspect ratio + if img_aspect_ratio > target_aspect_ratio: + new_width = int(img_h * target_aspect_ratio) + left = (img_w - new_width) // 2 + img = img[:, :, left:left + new_width, :] + else: + new_height = int(img_w / target_aspect_ratio) + top = (img_h - new_height) // 2 + img = img[:, top:top + new_height, :, :] + + # Resize to target size + img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False) + img = img.permute(0, 2, 3, 1) + + return img + class Blur: def __init__(self): pass @@ -100,7 +128,7 @@ class Blur: return (blurred,) -class Dither: +class Quantize: def __init__(self): pass @@ -109,51 +137,37 @@ class Dither: return { "required": { "image": ("IMAGE",), - "bits": ("INT", { - "default": 4, + "colors": ("INT", { + "default": 256, "min": 1, - "max": 8, + "max": 256, "step": 1 }), + "dither": (["none", "floyd-steinberg"],), }, } RETURN_TYPES = ("IMAGE",) - FUNCTION = "dither" + FUNCTION = "quantize" CATEGORY = "postprocessing" - def dither(self, image: torch.Tensor, bits: int): + def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) + dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE + for b in range(batch_size): tensor_image = image[b] - img = (tensor_image * 255) - height, width, _ = img.shape + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') - scale = 255 / (2**bits - 1) + palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 + quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) - for y in range(height): - for x in range(width): - old_pixel = img[y, x].clone() - new_pixel = torch.round(old_pixel / scale) * scale - img[y, x] = new_pixel - - quant_error = old_pixel - new_pixel - - if x + 1 < width: - img[y, x + 1] += quant_error * 7 / 16 - if y + 1 < height: - if x - 1 >= 0: - img[y + 1, x - 1] += quant_error * 3 / 16 - img[y + 1, x] += quant_error * 5 / 16 - if x + 1 < width: - img[y + 1, x + 1] += quant_error * 1 / 16 - - dithered = img / 255 - tensor = dithered.unsqueeze(0) - result[b] = tensor + quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 + result[b] = quantized_array return (result,) @@ -210,6 +224,6 @@ class Sharpen: NODE_CLASS_MAPPINGS = { "Blend": Blend, "Blur": Blur, - "Dither": Dither, + "Quantize": Quantize, "Sharpen": Sharpen, }