diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2eefc4c5..b7fd8cd6 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -122,10 +122,34 @@ class LatentBatch: samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) return (samples_out,) +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, }