Add a LatentBatch node.
This commit is contained in:
parent
719fa0866f
commit
6596654d47
|
@ -3,9 +3,7 @@ import torch
|
|||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
latent.movedim(1, -1)
|
||||
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
|
||||
latent.movedim(-1, 1)
|
||||
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
||||
|
||||
|
||||
|
@ -102,9 +100,32 @@ class LatentInterpolate:
|
|||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
class LatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def batch(self, samples1, samples2):
|
||||
samples_out = samples1.copy()
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
if s1.shape[1:] != s2.shape[1:]:
|
||||
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
|
||||
s = torch.cat((s1, s2), dim=0)
|
||||
samples_out["samples"] = s
|
||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentBatch": LatentBatch,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue