Make nodes map over input lists (#579)
* allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatch
This commit is contained in:
parent
997dd1b131
commit
1201d2eae5
|
@ -2,17 +2,26 @@ import torch
|
|||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def prepare_noise(latent_image, seed, skip=0):
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
for _ in range(skip):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return noise
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
import torch
|
||||
|
||||
class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
@staticmethod
|
||||
def get_batch(latents, list_ind, offset):
|
||||
'''prepare a batch out of the list of latents'''
|
||||
samples = latents[list_ind]['samples']
|
||||
shape = samples.shape
|
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||
if mask.shape[0] < samples.shape[0]:
|
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||
if 'batch_index' in latents[list_ind]:
|
||||
batch_inds = latents[list_ind]['batch_index']
|
||||
else:
|
||||
batch_inds = [x+offset for x in range(shape[0])]
|
||||
return samples, mask, batch_inds
|
||||
|
||||
@staticmethod
|
||||
def get_slices(indexable, num, batch_size):
|
||||
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||
slices = []
|
||||
for i in range(num):
|
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||
if num * batch_size < len(indexable):
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
return list(zip(*result))
|
||||
|
||||
@staticmethod
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
|
||||
for i in range(len(latents)):
|
||||
# fetch new entry of list
|
||||
#samples, masks, indices = self.get_batch(latents, i)
|
||||
next_batch = self.get_batch(latents, i, processed)
|
||||
processed += len(next_batch[2])
|
||||
# set to current if current is None
|
||||
if current_batch[0] is None:
|
||||
current_batch = next_batch
|
||||
# add previous to list if dimensions do not match
|
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
current_batch = next_batch
|
||||
# cat if everything checks out
|
||||
else:
|
||||
current_batch = self.cat_batch(current_batch, next_batch)
|
||||
|
||||
# add to list if dimensions gone above target batch size
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
current_batch = remainder
|
||||
|
||||
#add remainder
|
||||
if current_batch[0] is not None:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
|
||||
#get rid of empty masks
|
||||
for s in output_list:
|
||||
if s['noise_mask'].mean() == 1.0:
|
||||
del s['noise_mask']
|
||||
|
||||
return (output_list,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RebatchLatents": LatentRebatch,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
90
execution.py
90
execution.py
|
@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = input_data
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = prompt
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = extra_data['extra_pnginfo']
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = unique_id
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
results = []
|
||||
if intput_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
|
@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id]:
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
|
@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
|
@ -261,9 +321,11 @@ def validate_inputs(prompt, item, validated):
|
|||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
if ret != True:
|
||||
return (False, "{}, {}".format(class_type, ret))
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
return (False, "{}, {}".format(class_type, r))
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
|
|
57
nodes.py
57
nodes.py
|
@ -629,18 +629,57 @@ class LatentFromBatch:
|
|||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "rotate"
|
||||
FUNCTION = "frombatch"
|
||||
|
||||
CATEGORY = "latent"
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def rotate(self, samples, batch_index):
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
s["samples"] = s_in[batch_index:batch_index + 1].clone()
|
||||
s["batch_index"] = batch_index
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] == 1:
|
||||
s["noise_mask"] = masks.clone()
|
||||
else:
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||
if "batch_index" not in s:
|
||||
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||
if "batch_index" in s:
|
||||
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
|
@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
|
@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentFromBatch": LatentFromBatch,
|
||||
"RepeatLatentBatch": RepeatLatentBatch,
|
||||
"SaveImage": SaveImage,
|
||||
"PreviewImage": PreviewImage,
|
||||
"LoadImage": LoadImage,
|
||||
|
@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"EmptyLatentImage": "Empty Latent Image",
|
||||
"LatentUpscale": "Upscale Latent",
|
||||
"LatentComposite": "Latent Composite",
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
# Image
|
||||
"SaveImage": "Save Image",
|
||||
"PreviewImage": "Preview Image",
|
||||
|
@ -1299,3 +1341,4 @@ def init_custom_nodes():
|
|||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||
|
|
|
@ -268,6 +268,7 @@ class PromptServer():
|
|||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = x
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
||||
|
|
|
@ -976,7 +976,8 @@ export class ComfyApp {
|
|||
for (const o in nodeData["output"]) {
|
||||
const output = nodeData["output"][o];
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
this.addOutput(outputName, output);
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
|
|
Loading…
Reference in New Issue