diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index 001de39f..cedf39d6 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -1,4 +1,5 @@
 import comfy.utils
+import torch
 
 def reshape_latent_to(target_shape, latent):
     if latent.shape[1:] != target_shape[1:]:
@@ -67,8 +68,43 @@ class LatentMultiply:
         samples_out["samples"] = s1 * multiplier
         return (samples_out,)
 
+class LatentInterpolate:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples1": ("LATENT",),
+                              "samples2": ("LATENT",),
+                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "op"
+
+    CATEGORY = "latent/advanced"
+
+    def op(self, samples1, samples2, ratio):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+
+        m1 = torch.linalg.vector_norm(s1, dim=(1))
+        m2 = torch.linalg.vector_norm(s2, dim=(1))
+
+        s1 = torch.nan_to_num(s1 / m1)
+        s2 = torch.nan_to_num(s2 / m2)
+
+        t = (s1 * ratio + s2 * (1.0 - ratio))
+        mt = torch.linalg.vector_norm(t, dim=(1))
+        st = torch.nan_to_num(t / mt)
+
+        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
+        return (samples_out,)
+
 NODE_CLASS_MAPPINGS = {
     "LatentAdd": LatentAdd,
     "LatentSubtract": LatentSubtract,
     "LatentMultiply": LatentMultiply,
+    "LatentInterpolate": LatentInterpolate,
 }
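
For context, a minimal sketch (not part of the patch) of exercising the new node outside a ComfyUI graph: the node normalizes each latent along the channel dimension, lerps the normalized tensors by `ratio`, then rescales the result by the interpolated norms so the blend keeps roughly the magnitude of its inputs instead of washing out toward zero at intermediate ratios. The import assumes the script runs from a ComfyUI checkout so `comfy.utils` resolves; the 1x4x64x64 dummy tensors are illustrative only.

```python
# Illustrative sketch, assuming a ComfyUI checkout with this patch applied.
import torch
from comfy_extras.nodes_latent import LatentInterpolate

# Hypothetical latents with the shape of a 512x512 SD image latent (B, C, H, W).
samples1 = {"samples": torch.randn(1, 4, 64, 64)}
samples2 = {"samples": torch.randn(1, 4, 64, 64)}

node = LatentInterpolate()
# ratio=1.0 leans fully on samples1, ratio=0.0 fully on samples2.
(blended,) = node.op(samples1, samples2, ratio=0.5)
print(blended["samples"].shape)  # torch.Size([1, 4, 64, 64])
```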