use common_upcale in blend
This commit is contained in:
parent
fa2febc062
commit
56196ab0f7
|
@ -3,6 +3,8 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
class Blend:
|
class Blend:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -31,7 +33,9 @@ class Blend:
|
||||||
|
|
||||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
||||||
if image1.shape != image2.shape:
|
if image1.shape != image2.shape:
|
||||||
image2 = self.crop_and_resize(image2, image1.shape)
|
image2 = image2.permute(0, 3, 1, 2)
|
||||||
|
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
|
||||||
|
image2 = image2.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
blended_image = self.blend_mode(image1, image2, blend_mode)
|
blended_image = self.blend_mode(image1, image2, blend_mode)
|
||||||
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
||||||
|
@ -55,29 +59,6 @@ class Blend:
|
||||||
def g(self, x):
|
def g(self, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
|
|
||||||
batch_size, img_h, img_w, img_c = img.shape
|
|
||||||
_, target_h, target_w, _ = target_shape
|
|
||||||
img_aspect_ratio = img_w / img_h
|
|
||||||
target_aspect_ratio = target_w / target_h
|
|
||||||
|
|
||||||
# Crop center of the image to the target aspect ratio
|
|
||||||
if img_aspect_ratio > target_aspect_ratio:
|
|
||||||
new_width = int(img_h * target_aspect_ratio)
|
|
||||||
left = (img_w - new_width) // 2
|
|
||||||
img = img[:, :, left:left + new_width, :]
|
|
||||||
else:
|
|
||||||
new_height = int(img_w / target_aspect_ratio)
|
|
||||||
top = (img_h - new_height) // 2
|
|
||||||
img = img[:, top:top + new_height, :, :]
|
|
||||||
|
|
||||||
# Resize to target size
|
|
||||||
img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
|
||||||
img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
|
|
||||||
img = img.permute(0, 2, 3, 1)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
class Blur:
|
class Blur:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue