Make nodes map over input lists (#579)

* allow nodes to map over lists

* make work with IS_CHANGED and VALIDATE_INPUTS

* give list outputs distinct socket shape

* add rebatch node

* add batch index logic

* add repeat latent batch

* deal with noise mask edge cases in latentfrombatch
This commit is contained in:
BlenderNeko 2023-05-13 17:15:45 +02:00 committed by GitHub
parent 997dd1b131
commit 1201d2eae5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 250 additions and 26 deletions

View File

@ -2,17 +2,26 @@ import torch
import comfy.model_management import comfy.model_management
import comfy.samplers import comfy.samplers
import math import math
import numpy as np
def prepare_noise(latent_image, seed, skip=0): def prepare_noise(latent_image, seed, noise_inds=None):
""" """
creates random noise given a latent image and a seed. creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed optional arg skip can be used to skip and discard x number of noise generations for a given seed
""" """
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
for _ in range(skip): if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if i in unique_inds:
return noise noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions""" """ensures noise mask is of proper dimensions"""

View File

@ -0,0 +1,108 @@
import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
@staticmethod
def get_batch(latents, list_ind, offset):
'''prepare a batch out of the list of latents'''
samples = latents[list_ind]['samples']
shape = samples.shape
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
if mask.shape[0] < samples.shape[0]:
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
if 'batch_index' in latents[list_ind]:
batch_inds = latents[list_ind]['batch_index']
else:
batch_inds = [x+offset for x in range(shape[0])]
return samples, mask, batch_inds
@staticmethod
def get_slices(indexable, num, batch_size):
'''divides an indexable object into num slices of length batch_size, and a remainder'''
slices = []
for i in range(num):
slices.append(indexable[i*batch_size:(i+1)*batch_size])
if num * batch_size < len(indexable):
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
return list(zip(*result))
@staticmethod
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
current_batch = remainder
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
for s in output_list:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}

View File

@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj input_data_all[x] = obj
else: else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data input_data_all[x] = [input_data]
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
h = valid_inputs["hidden"] h = valid_inputs["hidden"]
for x in h: for x in h:
if h[x] == "PROMPT": if h[x] == "PROMPT":
input_data_all[x] = prompt input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO": if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data: if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo'] input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID": if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id input_data_all[x] = [unique_id]
return input_data_all return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if intput_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
server.send_sync("executing", { "node": unique_id }, server.client_id) server.send_sync("executing", { "node": unique_id }, server.client_id)
obj = class_def() obj = class_def()
nodes.before_node_execution() output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) outputs[unique_id] = output_data
if "ui" in outputs[unique_id]: if len(output_ui) > 0:
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id)
if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"]
executed.add(unique_id) executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item): def recursive_will_execute(prompt, outputs, current_item):
@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
input_data_all = get_input_data(inputs, class_def, unique_id, outputs) input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None: if input_data_all is not None:
try: try:
is_changed = class_def.IS_CHANGED(**input_data_all) #is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed prompt[unique_id]['is_changed'] = is_changed
except: except:
to_delete = True to_delete = True
@ -261,9 +321,11 @@ def validate_inputs(prompt, item, validated):
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id) input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all) #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True: ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
return (False, "{}, {}".format(class_type, ret)) for r in ret:
if r != True:
return (False, "{}, {}".format(class_type, r))
else: else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:

View File

@ -629,18 +629,57 @@ class LatentFromBatch:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "rotate" FUNCTION = "frombatch"
CATEGORY = "latent" CATEGORY = "latent/batch"
def rotate(self, samples, batch_index): def frombatch(self, samples, batch_index, length):
s = samples.copy() s = samples.copy()
s_in = samples["samples"] s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index) batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone() length = min(s_in.shape[0] - batch_index, length)
s["batch_index"] = batch_index s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,) return (s,)
class LatentUpscale: class LatentUpscale:
@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if disable_noise: if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else: else:
skip = latent["batch_index"] if "batch_index" in latent else 0 batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, skip) noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None noise_mask = None
if "noise_mask" in latent: if "noise_mask" in latent:
@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage, "EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale, "LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch, "LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage, "SaveImage": SaveImage,
"PreviewImage": PreviewImage, "PreviewImage": PreviewImage,
"LoadImage": LoadImage, "LoadImage": LoadImage,
@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentImage": "Empty Latent Image", "EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent", "LatentUpscale": "Upscale Latent",
"LatentComposite": "Latent Composite", "LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image # Image
"SaveImage": "Save Image", "SaveImage": "Save Image",
"PreviewImage": "Preview Image", "PreviewImage": "Preview Image",
@ -1299,3 +1341,4 @@ def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))

View File

@ -268,6 +268,7 @@ class PromptServer():
info = {} info = {}
info['input'] = obj_class.INPUT_TYPES() info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x

View File

@ -976,7 +976,8 @@ export class ComfyApp {
for (const o in nodeData["output"]) { for (const o in nodeData["output"]) {
const output = nodeData["output"][o]; const output = nodeData["output"][o];
const outputName = nodeData["output_name"][o] || output; const outputName = nodeData["output_name"][o] || output;
this.addOutput(outputName, output); const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
} }
const s = this.computeSize(); const s = this.computeSize();