Group node fixes (#2259)

* Prevent cleaning graph state on undo/redo

* Remove pause rendering due to LG bug

* Fix crash on disconnected internal reroutes

* Fix widget inputs being incorrect order and value

* Fix initial primitive values on connect

* basic support for basic rerouted converted inputs

* Populate primitive to reroute input

* dont crash on bad primitive links

* Fix convert to group changing control value

* reduce restrictions

* fix random crash in tests
This commit is contained in:
pythongosssss 2023-12-13 05:56:39 +00:00 committed by GitHub
parent b454a67bb9
commit 390078904c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 275 additions and 27 deletions

View File

@ -1,7 +1,7 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils");
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
@ -273,7 +273,7 @@ describe("group node", () => {
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
@ -283,7 +283,7 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
@ -407,12 +407,18 @@ describe("group node", () => {
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
const preview = ez.PreviewImage(decode.outputs[0]);
expect((await graph.toPrompt()).output).toEqual({
const output = {
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
});
};
expect((await graph.toPrompt()).output).toEqual(output);
// Ensure missing connections dont cause errors
group2.inputs.VAE.disconnect();
delete output[decode.id].inputs.vae;
expect((await graph.toPrompt()).output).toEqual(output);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
@ -673,6 +679,55 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");
// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();
// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];
await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
@ -846,4 +901,73 @@ describe("group node", () => {
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
test("internal reroutes work with converted inputs and merge options", async () => {
const { ez, graph, app } = await start();
const vae = ez.VAELoader();
const latent = ez.EmptyLatentImage();
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
const scale = ez.ImageScale(decode.outputs.IMAGE);
ez.PreviewImage(scale.outputs.IMAGE);
const r1 = ez.Reroute();
const r2 = ez.Reroute();
latent.widgets.width.value = 64;
latent.widgets.height.value = 128;
latent.widgets.width.convertToInput();
latent.widgets.height.convertToInput();
latent.widgets.batch_size.convertToInput();
scale.widgets.width.convertToInput();
scale.widgets.height.convertToInput();
r1.inputs[0].input.label = "hbw";
r1.outputs[0].connectTo(latent.inputs.height);
r1.outputs[0].connectTo(latent.inputs.batch_size);
r1.outputs[0].connectTo(scale.inputs.width);
r2.inputs[0].input.label = "wh";
r2.outputs[0].connectTo(latent.inputs.width);
r2.outputs[0].connectTo(scale.inputs.height);
const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
expect(group.inputs[0].input.type).toBe("VAE");
expect(group.inputs[1].input.type).toBe("INT");
expect(group.inputs[2].input.type).toBe("INT");
const p1 = ez.PrimitiveNode();
const p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(group.inputs[1]);
p2.outputs[0].connectTo(group.inputs[2]);
expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p1.widgets.value.value).toBe(128);
expect(p2.widgets.value.value).toBe(64);
p1.widgets.value.value = 16;
p2.widgets.value.value = 32;
await checkBeforeAndAfterReload(graph, async (r) => {
const id = (v) => (r ? `${group.id}:` : "") + v;
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
[id(4)]: {
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
class_type: "ImageScale",
},
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
});
});
});
});

View File

@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input;
}
get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}
disconnect() {
this.node.node.disconnectInput(this.index);
}

View File

@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
return { ckpt, pos, neg, empty, sampler, decode, save };
}
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}
export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}

View File

@ -174,6 +174,11 @@ export class GroupNodeConfig {
node.index = i;
this.processNode(node, seenInputs, seenOutputs);
}
for (const p of this.#convertedToProcess) {
p();
}
this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
}
@ -192,7 +197,10 @@ export class GroupNodeConfig {
if (!this.linksFrom[sourceNodeId]) {
this.linksFrom[sourceNodeId] = {};
}
this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
}
this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);
if (!this.linksTo[targetNodeId]) {
this.linksTo[targetNodeId] = {};
@ -230,11 +238,11 @@ export class GroupNodeConfig {
// Skip as its not linked
if (!linksFrom) return;
let type = linksFrom["0"][5];
let type = linksFrom["0"][0][5];
if (type === "COMBO") {
// Use the array items
const source = node.outputs[0].widget.name;
const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
const fromType = globalDefs[fromTypeName];
const input = fromType.input.required[source] ?? fromType.input.optional[source];
type = input[0];
@ -258,10 +266,33 @@ export class GroupNodeConfig {
return null;
}
let config = {};
let rerouteType = "*";
if (linksFrom) {
const [, , id, slot] = linksFrom["0"];
rerouteType = this.nodeData.nodes[id].inputs[slot].type;
for (const [, , id, slot] of linksFrom["0"]) {
const node = this.nodeData.nodes[id];
const input = node.inputs[slot];
if (rerouteType === "*") {
rerouteType = input.type;
}
if (input.widget) {
const targetDef = globalDefs[node.type];
const targetWidget =
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const widget = [targetWidget[0], config];
const res = mergeIfValid(
{
widget,
},
targetWidget,
false,
null,
widget
);
config = res?.customConfig ?? config;
}
}
} else if (linksTo) {
const [id, slot] = linksTo["0"];
rerouteType = this.nodeData.nodes[id].outputs[slot].type;
@ -282,10 +313,11 @@ export class GroupNodeConfig {
}
}
config.forceInput = true;
return {
input: {
required: {
[rerouteType]: [rerouteType, {}],
[rerouteType]: [rerouteType, config],
},
},
output: [rerouteType],
@ -420,10 +452,18 @@ export class GroupNodeConfig {
defaultInput: true,
});
this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
if (!this.oldToNewWidgetMap[node.index]) {
this.oldToNewWidgetMap[node.index] = {};
}
this.oldToNewWidgetMap[node.index][inputName] = name;
inputMap[slots.length + i] = this.inputCount++;
}
}
#convertedToProcess = [];
processNodeInputs(node, seenInputs, inputs) {
const inputMapping = [];
@ -434,7 +474,11 @@ export class GroupNodeConfig {
const linksTo = this.linksTo[node.index] ?? {};
const inputMap = (this.oldToNewInputMap[node.index] = {});
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() =>
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
return inputMapping;
}
@ -597,11 +641,15 @@ export class GroupNodeHandler {
const output = this.groupData.newToOldOutputMap[link.origin_slot];
let innerNode = this.innerNodes[output.node.index];
let l;
while (innerNode.type === "Reroute") {
while (innerNode?.type === "Reroute") {
l = innerNode.getInputLink(0);
innerNode = innerNode.getInputNode(0);
}
if (!innerNode) {
return null;
}
if (l && GroupNodeHandler.isGroupNode(innerNode)) {
return innerNode.updateLink(l);
}
@ -669,6 +717,8 @@ export class GroupNodeHandler {
top = newNode.pos[1];
}
if (!newNode.widgets) continue;
const map = this.groupData.oldToNewWidgetMap[innerNode.index];
if (map) {
const widgets = Object.keys(map);
@ -725,7 +775,7 @@ export class GroupNodeHandler {
}
};
const reconnectOutputs = () => {
const reconnectOutputs = (selectedIds) => {
for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
const output = node.outputs[groupOutputId];
if (!output.links) continue;
@ -865,7 +915,7 @@ export class GroupNodeHandler {
if (innerNode.type === "PrimitiveNode") {
innerNode.primitiveValue = newValue;
const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
for (const linked of primitiveLinked) {
for (const linked of primitiveLinked ?? []) {
const node = this.innerNodes[linked.nodeId];
const widget = node.widgets.find((w) => w.name === linked.inputName);
@ -874,6 +924,18 @@ export class GroupNodeHandler {
}
}
continue;
} else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index];
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
const widget = node.widgets?.find((w) => w.name === input.widget.name);
if (widget) {
widget.value = newValue;
}
}
}
}
const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
@ -901,33 +963,58 @@ export class GroupNodeHandler {
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
}
}
return true;
}
populateReroute(node, nodeId, map) {
if (node.type !== "Reroute") return;
const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
if (!link) return;
const [, , targetNodeId, targetNodeSlot] = link;
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
const inputs = targetNode.inputs;
const targetWidget = inputs?.[targetNodeSlot].widget;
if (!targetWidget) return;
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
const v = targetNode.widgets_values?.[targetNodeSlot - offset];
if (v == null) return;
const widgetName = Object.values(map)[0];
const widget = this.node.widgets.find(w => w.name === widgetName);
if(widget) {
widget.value = v;
}
}
populateWidgets() {
if (!this.node.widgets) return;
for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
const node = this.groupData.nodeData.nodes[nodeId];
if (!node.widgets_values?.length) continue;
const map = this.groupData.oldToNewWidgetMap[nodeId];
const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
const widgets = Object.keys(map);
if (!node.widgets_values?.length) {
// special handling for populating values into reroutes
// this allows primitives connect to them to pick up the correct value
this.populateReroute(node, nodeId, map);
continue;
}
let linkedShift = 0;
for (let i = 0; i < widgets.length; i++) {
const oldName = widgets[i];
const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex];
if (!newName) {
// New name will be null if its a converted widget
this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) {
// Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
linkedShift += innerWidget.linkedWidgets?.length ?? 0;
continue;
}
if (widgetIndex === -1) {
continue;
}

View File

@ -54,6 +54,7 @@ app.registerExtension({
const linkId = currentNode.inputs[0].link;
if (linkId !== null) {
const link = app.graph.links[linkId];
if (!link) return;
const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type;
if (type === "Reroute") {

View File

@ -180,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
@ -633,6 +633,14 @@ app.registerExtension({
}
}
// Restore any saved control values
const controlValues = this.controlValues;
if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
for(let i = 0; i < controlValues.length; i++) {
this.widgets[i + 1].value = controlValues[i];
}
}
// When our value changes, update other widgets to reflect our changes
// e.g. so LoadImage shows correct image
const callback = widget.callback;
@ -721,6 +729,15 @@ app.registerExtension({
w.onRemove();
}
}
// Temporarily store the current values in case the node is being recreated
// e.g. by group node conversion
this.controlValues = [];
this.lastType = this.widgets[0]?.type;
for(let i = 1; i < this.widgets.length; i++) {
this.controlValues.push(this.widgets[i].value);
}
setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
this.widgets.length = 0;
}
}

View File

@ -1774,7 +1774,9 @@ export class ComfyApp {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
if (link) {
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
}
}