diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..202121e1
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,9 @@
+{
+ "path-intellisense.mappings": {
+ "../": "${workspaceFolder}/web/extensions/core"
+ },
+ "[python]": {
+ "editor.defaultFormatter": "ms-python.autopep8"
+ },
+ "python.formatting.provider": "none"
+}
diff --git a/tests-ui/setup.js b/tests-ui/setup.js
index 0f368ab2..8bbd9dcd 100644
--- a/tests-ui/setup.js
+++ b/tests-ui/setup.js
@@ -20,6 +20,7 @@ async function setup() {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
+ objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");
diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js
new file mode 100644
index 00000000..ce54c115
--- /dev/null
+++ b/tests-ui/tests/groupNode.test.js
@@ -0,0 +1,818 @@
+// @ts-check
+///
+
+const { start, createDefaultWorkflow } = require("../utils");
+const lg = require("../utils/litegraph");
+
+describe("group node", () => {
+ beforeEach(() => {
+ lg.setup(global);
+ });
+
+ afterEach(() => {
+ lg.teardown(global);
+ });
+
+ /**
+ *
+ * @param {*} app
+ * @param {*} graph
+ * @param {*} name
+ * @param {*} nodes
+ * @returns { Promise> }
+ */
+ async function convertToGroup(app, graph, name, nodes) {
+ // Select the nodes we are converting
+ for (const n of nodes) {
+ n.select(true);
+ }
+
+ expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
+ nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
+ );
+
+ global.prompt = jest.fn().mockImplementation(() => name);
+ const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
+
+ // Check group name was requested
+ expect(window.prompt).toHaveBeenCalled();
+
+ // Ensure old nodes are removed
+ for (const n of nodes) {
+ expect(n.isRemoved).toBeTruthy();
+ }
+
+ expect(groupNode.type).toEqual("workflow/" + name);
+
+ return graph.find(groupNode);
+ }
+
+ /**
+ * @param { Record | number[] } idMap
+ * @param { Record> } valueMap
+ */
+ function getOutput(idMap = {}, valueMap = {}) {
+ if (idMap instanceof Array) {
+ idMap = idMap.reduce((p, n) => {
+ p[n] = n + "";
+ return p;
+ }, {});
+ }
+ const expected = {
+ 1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
+ 2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
+ 3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
+ 4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
+ 5: {
+ inputs: {
+ seed: 0,
+ steps: 20,
+ cfg: 8,
+ sampler_name: "euler",
+ scheduler: "normal",
+ denoise: 1,
+ model: ["1", 0],
+ positive: ["2", 0],
+ negative: ["3", 0],
+ latent_image: ["4", 0],
+ ...valueMap?.[5],
+ },
+ class_type: "KSampler",
+ },
+ 6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
+ 7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
+ };
+
+ // Map old IDs to new at the top level
+ const mapped = {};
+ for (const oldId in idMap) {
+ mapped[idMap[oldId]] = expected[oldId];
+ delete expected[oldId];
+ }
+ Object.assign(mapped, expected);
+
+ // Map old IDs to new inside links
+ for (const k in mapped) {
+ for (const input in mapped[k].inputs) {
+ const v = mapped[k].inputs[input];
+ if (v instanceof Array) {
+ if (v[0] in idMap) {
+ v[0] = idMap[v[0]] + "";
+ }
+ }
+ }
+ }
+
+ return mapped;
+ }
+
+ test("can be created from selected nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
+
+ // Ensure links are now to the group node
+ expect(group.inputs).toHaveLength(2);
+ expect(group.outputs).toHaveLength(3);
+
+ expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
+ expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
+
+ // ckpt clip to both clip inputs on the group
+ expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [group.id, 0],
+ [group.id, 1],
+ ]);
+
+ // group conditioning to sampler
+ expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 1],
+ ]);
+ // group conditioning 2 to sampler
+ expect(
+ group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
+ ).toEqual([[nodes.sampler.id, 2]]);
+ // group latent to sampler
+ expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 3],
+ ]);
+ });
+
+ test("maintains all output links on conversion", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const save2 = ez.SaveImage(...nodes.decode.outputs);
+ const save3 = ez.SaveImage(...nodes.decode.outputs);
+ // Ensure an output with multiple links maintains them on convert to group
+ const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
+ expect(group.outputs[0].connections.length).toBe(3);
+ expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
+ expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
+ expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
+
+ // and they're still linked when converting back to nodes
+ const newNodes = group.menu["Convert to nodes"].call();
+ const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
+ expect(decode.outputs[0].connections.length).toBe(3);
+ expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
+ expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
+ expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
+ });
+ test("can be be converted back to nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
+ const group = await convertToGroup(app, graph, "test", toConvert);
+
+ // Edit some values to ensure they are set back onto the converted nodes
+ expect(group.widgets["text"].value).toBe("positive");
+ group.widgets["text"].value = "pos";
+ expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
+ group.widgets["CLIPTextEncode text"].value = "neg";
+ expect(group.widgets["width"].value).toBe(512);
+ group.widgets["width"].value = 1024;
+ expect(group.widgets["sampler_name"].value).toBe("euler");
+ group.widgets["sampler_name"].value = "ddim";
+ expect(group.widgets["control_after_generate"].value).toBe("randomize");
+ group.widgets["control_after_generate"].value = "fixed";
+
+ /** @type { Array } */
+ group.menu["Convert to nodes"].call();
+
+ // ensure widget values are set
+ const pos = graph.find(nodes.pos.id);
+ expect(pos.node.type).toBe("CLIPTextEncode");
+ expect(pos.widgets["text"].value).toBe("pos");
+ const neg = graph.find(nodes.neg.id);
+ expect(neg.node.type).toBe("CLIPTextEncode");
+ expect(neg.widgets["text"].value).toBe("neg");
+ const empty = graph.find(nodes.empty.id);
+ expect(empty.node.type).toBe("EmptyLatentImage");
+ expect(empty.widgets["width"].value).toBe(1024);
+ const sampler = graph.find(nodes.sampler.id);
+ expect(sampler.node.type).toBe("KSampler");
+ expect(sampler.widgets["sampler_name"].value).toBe("ddim");
+ expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
+
+ // validate links
+ expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [pos.id, 0],
+ [neg.id, 0],
+ ]);
+
+ expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 1],
+ ]);
+
+ expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 2],
+ ]);
+
+ expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 3],
+ ]);
+ });
+ test("it can embed reroutes as inputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Add and connect a reroute to the clip text encodes
+ const reroute = ez.Reroute();
+ nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
+ reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
+ reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
+
+ // Convert to group and ensure we only have 1 input of the correct type
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
+ expect(group.inputs).toHaveLength(1);
+ expect(group.inputs[0].input.type).toEqual("CLIP");
+
+ expect((await graph.toPrompt()).output).toEqual(getOutput());
+ });
+ test("it can embed reroutes as outputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Add a reroute with no output so we output IMAGE even though its used internally
+ const reroute = ez.Reroute();
+ nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
+
+ // Convert to group and ensure there is an IMAGE output
+ const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
+ expect(group.outputs).toHaveLength(1);
+ expect(group.outputs[0].output.type).toEqual("IMAGE");
+ expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
+ });
+ test("it can embed reroutes as pipes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Use reroutes as a pipe
+ const rerouteModel = ez.Reroute();
+ const rerouteClip = ez.Reroute();
+ const rerouteVae = ez.Reroute();
+ nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
+ nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
+ nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
+
+ const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
+
+ expect(group.outputs).toHaveLength(3);
+ expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
+
+ expect(group.outputs).toHaveLength(3);
+ expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
+
+ group.outputs[0].connectTo(nodes.sampler.inputs.model);
+ group.outputs[1].connectTo(nodes.pos.inputs.clip);
+ group.outputs[1].connectTo(nodes.neg.inputs.clip);
+ });
+ test("can handle reroutes used internally", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ let reroutes = [];
+ let prevNode = nodes.ckpt;
+ for(let i = 0; i < 5; i++) {
+ const reroute = ez.Reroute();
+ prevNode.outputs[0].connectTo(reroute.inputs[0]);
+ prevNode = reroute;
+ reroutes.push(reroute);
+ }
+ prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
+
+ const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
+ expect((await graph.toPrompt()).output).toEqual(getOutput());
+
+ group.menu["Convert to nodes"].call();
+ expect((await graph.toPrompt()).output).toEqual(getOutput());
+ });
+ test("creates with widget values from inner nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
+ nodes.pos.widgets.text.value = "hello";
+ nodes.neg.widgets.text.value = "world";
+ nodes.empty.widgets.width.value = 256;
+ nodes.empty.widgets.height.value = 1024;
+ nodes.sampler.widgets.seed.value = 1;
+ nodes.sampler.widgets.control_after_generate.value = "increment";
+ nodes.sampler.widgets.steps.value = 8;
+ nodes.sampler.widgets.cfg.value = 4.5;
+ nodes.sampler.widgets.sampler_name.value = "uni_pc";
+ nodes.sampler.widgets.scheduler.value = "karras";
+ nodes.sampler.widgets.denoise.value = 0.9;
+
+ const group = await convertToGroup(app, graph, "test", [
+ nodes.ckpt,
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ ]);
+
+ expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
+ expect(group.widgets["text"].value).toEqual("hello");
+ expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
+ expect(group.widgets["width"].value).toEqual(256);
+ expect(group.widgets["height"].value).toEqual(1024);
+ expect(group.widgets["seed"].value).toEqual(1);
+ expect(group.widgets["control_after_generate"].value).toEqual("increment");
+ expect(group.widgets["steps"].value).toEqual(8);
+ expect(group.widgets["cfg"].value).toEqual(4.5);
+ expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
+ expect(group.widgets["scheduler"].value).toEqual("karras");
+ expect(group.widgets["denoise"].value).toEqual(0.9);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
+ [nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
+ [nodes.pos.id]: { text: "hello" },
+ [nodes.neg.id]: { text: "world" },
+ [nodes.empty.id]: { width: 256, height: 1024 },
+ [nodes.sampler.id]: {
+ seed: 1,
+ steps: 8,
+ cfg: 4.5,
+ sampler_name: "uni_pc",
+ scheduler: "karras",
+ denoise: 0.9,
+ },
+ })
+ );
+ });
+ test("group inputs can be reroutes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+
+ const reroute = ez.Reroute();
+ nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
+
+ reroute.outputs[0].connectTo(group.inputs[0]);
+ reroute.outputs[0].connectTo(group.inputs[1]);
+
+ expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
+ });
+ test("group outputs can be reroutes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+
+ const reroute1 = ez.Reroute();
+ const reroute2 = ez.Reroute();
+ group.outputs[0].connectTo(reroute1.inputs[0]);
+ group.outputs[1].connectTo(reroute2.inputs[0]);
+
+ reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
+ reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
+
+ expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
+ });
+ test("groups can connect to each other", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
+
+ group1.outputs[0].connectTo(group2.inputs["positive"]);
+ group1.outputs[1].connectTo(group2.inputs["negative"]);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
+ );
+ });
+ test("displays generated image on group node", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ let group = await convertToGroup(app, graph, "test", [
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ nodes.decode,
+ nodes.save,
+ ]);
+
+ const { api } = require("../../web/scripts/api");
+
+ api.dispatchEvent(new CustomEvent("execution_start", {}));
+ api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
+ // Event should be forwarded to group node id
+ expect(+app.runningNodeId).toEqual(group.id);
+ expect(group.node["imgs"]).toBeFalsy();
+ api.dispatchEvent(
+ new CustomEvent("executed", {
+ detail: {
+ node: `${nodes.save.id}`,
+ output: {
+ images: [
+ {
+ filename: "test.png",
+ type: "output",
+ },
+ ],
+ },
+ },
+ })
+ );
+
+ // Trigger paint
+ group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
+
+ expect(group.node["images"]).toEqual([
+ {
+ filename: "test.png",
+ type: "output",
+ },
+ ]);
+
+ // Reload
+ const workflow = JSON.stringify((await graph.toPrompt()).workflow);
+ await app.loadGraphData(JSON.parse(workflow));
+ group = graph.find(group);
+
+ // Trigger inner nodes to get created
+ group.node["getInnerNodes"]();
+
+ // Check it works for internal node ids
+ api.dispatchEvent(new CustomEvent("execution_start", {}));
+ api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
+ // Event should be forwarded to group node id
+ expect(+app.runningNodeId).toEqual(group.id);
+ expect(group.node["imgs"]).toBeFalsy();
+ api.dispatchEvent(
+ new CustomEvent("executed", {
+ detail: {
+ node: `${group.id}:5`,
+ output: {
+ images: [
+ {
+ filename: "test2.png",
+ type: "output",
+ },
+ ],
+ },
+ },
+ })
+ );
+
+ // Trigger paint
+ group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
+
+ expect(group.node["images"]).toEqual([
+ {
+ filename: "test2.png",
+ type: "output",
+ },
+ ]);
+ });
+ test("allows widgets to be converted to inputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ group.widgets[0].convertToInput();
+
+ const primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(group.inputs["text"]);
+ primitive.widgets[0].value = "hello";
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput([nodes.pos.id, nodes.neg.id], {
+ [nodes.pos.id]: { text: "hello" },
+ })
+ );
+ });
+ test("can be copied", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ const group1 = await convertToGroup(app, graph, "test", [
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ nodes.decode,
+ nodes.save,
+ ]);
+
+ group1.widgets["text"].value = "hello";
+ group1.widgets["width"].value = 256;
+ group1.widgets["seed"].value = 1;
+
+ // Clone the node
+ group1.menu.Clone.call();
+ expect(app.graph._nodes).toHaveLength(3);
+ const group2 = graph.find(app.graph._nodes[2]);
+ expect(group2.node.type).toEqual("workflow/test");
+ expect(group2.id).not.toEqual(group1.id);
+
+ // Reconnect ckpt
+ nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
+ nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
+ nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
+ nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
+
+ group2.widgets["text"].value = "world";
+ group2.widgets["width"].value = 1024;
+ group2.widgets["seed"].value = 100;
+
+ let i = 0;
+ expect((await graph.toPrompt()).output).toEqual({
+ ...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
+ [nodes.empty.id]: { width: 256 },
+ [nodes.pos.id]: { text: "hello" },
+ [nodes.sampler.id]: { seed: 1 },
+ }),
+ ...getOutput(
+ {
+ [nodes.empty.id]: `${group2.id}:${i++}`,
+ [nodes.pos.id]: `${group2.id}:${i++}`,
+ [nodes.neg.id]: `${group2.id}:${i++}`,
+ [nodes.sampler.id]: `${group2.id}:${i++}`,
+ [nodes.decode.id]: `${group2.id}:${i++}`,
+ [nodes.save.id]: `${group2.id}:${i++}`,
+ },
+ {
+ [nodes.empty.id]: { width: 1024 },
+ [nodes.pos.id]: { text: "world" },
+ [nodes.sampler.id]: { seed: 100 },
+ }
+ ),
+ });
+
+ graph.arrange();
+ });
+ test("is embedded in workflow", async () => {
+ let { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ const workflow = JSON.stringify((await graph.toPrompt()).workflow);
+
+ // Clear the environment
+ ({ ez, graph, app } = await start({
+ resetEnv: true,
+ }));
+ // Ensure the node isnt registered
+ expect(() => ez["workflow/test"]).toThrow();
+
+ // Reload the workflow
+ await app.loadGraphData(JSON.parse(workflow));
+
+ // Ensure the node is found
+ group = graph.find(group);
+
+ // Generate prompt and ensure it is as expected
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ })
+ );
+ });
+ test("shows missing node error on missing internal node when loading graph data", async () => {
+ const { graph } = await start();
+
+ const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
+ await graph.app.loadGraphData({
+ last_node_id: 3,
+ last_link_id: 1,
+ nodes: [
+ {
+ id: 3,
+ type: "workflow/testerror",
+ },
+ ],
+ links: [],
+ groups: [],
+ config: {},
+ extra: {
+ groupNodes: {
+ testerror: {
+ nodes: [
+ {
+ type: "NotKSampler",
+ },
+ {
+ type: "NotVAEDecode",
+ },
+ ],
+ },
+ },
+ },
+ });
+
+ expect(dialogShow).toBeCalledTimes(1);
+ const call = dialogShow.mock.calls[0][0].innerHTML;
+ expect(call).toContain("the following node types were not found");
+ expect(call).toContain("NotKSampler");
+ expect(call).toContain("NotVAEDecode");
+ expect(call).toContain("workflow/testerror");
+ });
+ test("maintains widget inputs on conversion back to nodes", async () => {
+ const { ez, graph, app } = await start();
+ let pos = ez.CLIPTextEncode({ text: "positive" });
+ pos.node.title = "Positive";
+ let neg = ez.CLIPTextEncode({ text: "negative" });
+ neg.node.title = "Negative";
+ pos.widgets.text.convertToInput();
+ neg.widgets.text.convertToInput();
+
+ let primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(pos.inputs.text);
+ primitive.outputs[0].connectTo(neg.inputs.text);
+
+ const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
+ // This will use a primitive widget named 'value'
+ expect(group.widgets.length).toBe(1);
+ expect(group.widgets["value"].value).toBe("positive");
+
+ const newNodes = group.menu["Convert to nodes"].call();
+ pos = graph.find(newNodes.find((n) => n.title === "Positive"));
+ neg = graph.find(newNodes.find((n) => n.title === "Negative"));
+ primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
+
+ expect(pos.inputs).toHaveLength(2);
+ expect(neg.inputs).toHaveLength(2);
+ expect(primitive.outputs[0].connections).toHaveLength(2);
+
+ expect((await graph.toPrompt()).output).toEqual({
+ 1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
+ 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
+ });
+ });
+ test("adds widgets in node execution order", async () => {
+ const { ez, graph, app } = await start();
+ const scale = ez.LatentUpscale();
+ const save = ez.SaveImage();
+ const empty = ez.EmptyLatentImage();
+ const decode = ez.VAEDecode();
+
+ scale.outputs.LATENT.connectTo(decode.inputs.samples);
+ decode.outputs.IMAGE.connectTo(save.inputs.images);
+ empty.outputs.LATENT.connectTo(scale.inputs.samples);
+
+ const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
+ const widgets = group.widgets.map((w) => w.widget.name);
+ expect(widgets).toStrictEqual([
+ "width",
+ "height",
+ "batch_size",
+ "upscale_method",
+ "LatentUpscale width",
+ "LatentUpscale height",
+ "crop",
+ "filename_prefix",
+ ]);
+ });
+ test("adds output for external links when converting to group", async () => {
+ const { ez, graph, app } = await start();
+ const img = ez.EmptyLatentImage();
+ let decode = ez.VAEDecode(...img.outputs);
+ const preview1 = ez.PreviewImage(...decode.outputs);
+ const preview2 = ez.PreviewImage(...decode.outputs);
+
+ const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
+
+ // Ensure we have an output connected to the 2nd preview node
+ expect(group.outputs.length).toBe(1);
+ expect(group.outputs[0].connections.length).toBe(1);
+ expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
+
+ // Convert back and ensure bothe previews are still connected
+ group.menu["Convert to nodes"].call();
+ decode = graph.find(decode);
+ expect(decode.outputs[0].connections.length).toBe(2);
+ expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
+ expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
+ });
+ test("adds output for external links when converting to group when nodes are not in execution order", async () => {
+ const { ez, graph, app } = await start();
+ const sampler = ez.KSampler();
+ const ckpt = ez.CheckpointLoaderSimple();
+ const empty = ez.EmptyLatentImage();
+ const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
+ const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
+ const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
+ const save = ez.SaveImage(decode1.outputs.IMAGE);
+ ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
+ pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
+ neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
+ empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
+
+ const encode = ez.VAEEncode(decode1.outputs.IMAGE);
+ const vae = ez.VAELoader();
+ const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
+ const preview = ez.PreviewImage(decode2.outputs.IMAGE);
+ vae.outputs.VAE.connectTo(encode.inputs.vae);
+
+ const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
+
+ expect(group.outputs.length).toBe(3);
+ expect(group.outputs[0].output.name).toBe("VAE");
+ expect(group.outputs[0].output.type).toBe("VAE");
+ expect(group.outputs[1].output.name).toBe("IMAGE");
+ expect(group.outputs[1].output.type).toBe("IMAGE");
+ expect(group.outputs[2].output.name).toBe("LATENT");
+ expect(group.outputs[2].output.type).toBe("LATENT");
+
+ expect(group.outputs[0].connections.length).toBe(1);
+ expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
+ expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
+
+ expect(group.outputs[1].connections.length).toBe(1);
+ expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
+ expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
+
+ expect(group.outputs[2].connections.length).toBe(1);
+ expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
+ expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
+
+ expect((await graph.toPrompt()).output).toEqual({
+ ...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
+ [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
+ [encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
+ [decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
+ [preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
+ });
+ });
+ test("works with IMAGEUPLOAD widget", async () => {
+ const { ez, graph, app } = await start();
+ const img = ez.LoadImage();
+ const preview1 = ez.PreviewImage(img.outputs[0]);
+
+ const group = await convertToGroup(app, graph, "test", [img, preview1]);
+ const widget = group.widgets["upload"];
+ expect(widget).toBeTruthy();
+ expect(widget.widget.type).toBe("button");
+ });
+ test("internal primitive populates widgets for all linked inputs", async () => {
+ const { ez, graph, app } = await start();
+ const img = ez.LoadImage();
+ const scale1 = ez.ImageScale(img.outputs[0]);
+ const scale2 = ez.ImageScale(img.outputs[0]);
+ ez.PreviewImage(scale1.outputs[0]);
+ ez.PreviewImage(scale2.outputs[0]);
+
+ scale1.widgets.width.convertToInput();
+ scale2.widgets.height.convertToInput();
+
+ const primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(scale1.inputs.width);
+ primitive.outputs[0].connectTo(scale2.inputs.height);
+
+ const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
+ group.widgets.value.value = 100;
+ expect((await graph.toPrompt()).output).toEqual({
+ 1: {
+ inputs: { image: img.widgets.image.value, upload: "image" },
+ class_type: "LoadImage",
+ },
+ 2: {
+ inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
+ class_type: "ImageScale",
+ },
+ 3: {
+ inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
+ class_type: "ImageScale",
+ },
+ 4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
+ 5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
+ });
+ });
+ test("primitive control widgets values are copied on convert", async () => {
+ const { ez, graph, app } = await start();
+ const sampler = ez.KSampler();
+ sampler.widgets.seed.convertToInput();
+ sampler.widgets.sampler_name.convertToInput();
+
+ let p1 = ez.PrimitiveNode();
+ let p2 = ez.PrimitiveNode();
+ p1.outputs[0].connectTo(sampler.inputs.seed);
+ p2.outputs[0].connectTo(sampler.inputs.sampler_name);
+
+ p1.widgets.control_after_generate.value = "increment";
+ p2.widgets.control_after_generate.value = "decrement";
+ p2.widgets.control_filter_list.value = "/.*/";
+
+ p2.node.title = "p2";
+
+ const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
+ expect(group.widgets.control_after_generate.value).toBe("increment");
+ expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
+ expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
+
+ group.widgets.control_after_generate.value = "fixed";
+ group.widgets["p2 control_after_generate"].value = "randomize";
+ group.widgets["p2 control_filter_list"].value = "/.+/";
+
+ group.menu["Convert to nodes"].call();
+ p1 = graph.find(p1);
+ p2 = graph.find(p2);
+
+ expect(p1.widgets.control_after_generate.value).toBe("fixed");
+ expect(p2.widgets.control_after_generate.value).toBe("randomize");
+ expect(p2.widgets.control_filter_list.value).toBe("/.+/");
+ });
+});
diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js
index e1873105..8e191adf 100644
--- a/tests-ui/tests/widgetInputs.test.js
+++ b/tests-ui/tests/widgetInputs.test.js
@@ -202,8 +202,8 @@ describe("widget inputs", () => {
});
expect(dialogShow).toBeCalledTimes(1);
- expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found");
- expect(dialogShow.mock.calls[0][0]).toContain("TestNode");
+ expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
+ expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
});
test("defaultInput widgets can be converted back to inputs", async () => {
diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js
index 0e81fd47..898b82db 100644
--- a/tests-ui/utils/ezgraph.js
+++ b/tests-ui/utils/ezgraph.js
@@ -150,7 +150,7 @@ export class EzNodeMenuItem {
if (selectNode) {
this.node.select();
}
- this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
+ return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
@@ -240,8 +240,12 @@ export class EzNode {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
- select() {
- this.app.canvas.selectNode(this.node);
+ get isRemoved() {
+ return !this.app.graph.getNodeById(this.id);
+ }
+
+ select(addToSelection = false) {
+ this.app.canvas.selectNode(this.node, addToSelection);
}
// /**
@@ -275,12 +279,17 @@ export class EzNode {
if (!s) return p;
const name = s[nameProperty];
+ const item = new ctor(this, i, s);
// @ts-ignore
- if (!name || name in p) {
- throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
+ p.push(item);
+ if (name) {
+ // @ts-ignore
+ if (name in p) {
+ throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
+ }
}
// @ts-ignore
- p.push((p[name] = new ctor(this, i, s)));
+ p[name] = item;
return p;
}, Object.assign([], { $: this }));
}
@@ -348,6 +357,19 @@ export class EzGraph {
}, 10);
});
}
+
+ /**
+ * @returns { Promise<{
+ * workflow: {},
+ * output: Record
+ * }>}> }
+ */
+ toPrompt() {
+ // @ts-ignore
+ return this.app.graphToPrompt();
+ }
}
export const Ez = {
@@ -356,12 +378,12 @@ export const Ez = {
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
- * const [model, clip, vae] = ez.CheckpointLoaderSimple();
- * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" });
- * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" });
- * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage());
- * const [image] = ez.VAEDecode(latent, vae);
- * const saveNode = ez.SaveImage(image).node;
+ * const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
+ * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
+ * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
+ * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
+ * const [image] = ez.VAEDecode(latent, vae).outputs;
+ * const saveNode = ez.SaveImage(image);
* console.log(saveNode);
* graph.arrange();
* @param { app } app
diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js
index 01c58b21..eeccdb3d 100644
--- a/tests-ui/utils/index.js
+++ b/tests-ui/utils/index.js
@@ -1,21 +1,28 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
+const lg = require("./litegraph");
/**
*
- * @param { Parameters[0] } config
+ * @param { Parameters[0] & { resetEnv?: boolean } } config
* @returns
*/
export async function start(config = undefined) {
+ if(config?.resetEnv) {
+ jest.resetModules();
+ jest.resetAllMocks();
+ lg.setup(global);
+ }
+
mockApi(config);
const { app } = require("../../web/scripts/app");
await app.setup();
- return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]);
+ return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
/**
- * @param { ReturnType["graph"] } graph
- * @param { (hasReloaded: boolean) => (Promise | void) } cb
+ * @param { ReturnType["graph"] } graph
+ * @param { (hasReloaded: boolean) => (Promise | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
@@ -24,10 +31,10 @@ export async function checkBeforeAndAfterReload(graph, cb) {
}
/**
- * @param { string } name
- * @param { Record } input
+ * @param { string } name
+ * @param { Record } input
* @param { (string | string[])[] | Record } output
- * @returns { Record }
+ * @returns { Record }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
@@ -37,19 +44,19 @@ export function makeNodeDef(name, input, output = {}) {
output_name: [],
output_is_list: [],
input: {
- required: {}
+ required: {},
},
};
- for(const k in input) {
+ for (const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
- if(output instanceof Array) {
+ if (output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
- }, {})
+ }, {});
}
- for(const k in output) {
+ for (const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
@@ -68,4 +75,31 @@ export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
-}
\ No newline at end of file
+}
+
+/**
+ *
+ * @param { ReturnType["ez"] } ez
+ * @param { ReturnType["graph"] } graph
+ */
+export function createDefaultWorkflow(ez, graph) {
+ graph.clear();
+ const ckpt = ez.CheckpointLoaderSimple();
+
+ const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
+ const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
+
+ const empty = ez.EmptyLatentImage();
+ const sampler = ez.KSampler(
+ ckpt.outputs.MODEL,
+ pos.outputs.CONDITIONING,
+ neg.outputs.CONDITIONING,
+ empty.outputs.LATENT
+ );
+
+ const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
+ const save = ez.SaveImage(decode.outputs.IMAGE);
+ graph.arrange();
+
+ return { ckpt, pos, neg, empty, sampler, decode, save };
+}
diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js
index 17e8ac1a..dd150214 100644
--- a/tests-ui/utils/setup.js
+++ b/tests-ui/utils/setup.js
@@ -30,16 +30,20 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
+ const events = new EventTarget();
+ const mockApi = {
+ addEventListener: events.addEventListener.bind(events),
+ removeEventListener: events.removeEventListener.bind(events),
+ dispatchEvent: events.dispatchEvent.bind(events),
+ getSystemStats: jest.fn(),
+ getExtensions: jest.fn(() => mockExtensions),
+ getNodeDefs: jest.fn(() => mockNodeDefs),
+ init: jest.fn(),
+ apiURL: jest.fn((x) => "../../web/" + x),
+ };
jest.mock("../../web/scripts/api", () => ({
get api() {
- return {
- addEventListener: jest.fn(),
- getSystemStats: jest.fn(),
- getExtensions: jest.fn(() => mockExtensions),
- getNodeDefs: jest.fn(() => mockNodeDefs),
- init: jest.fn(),
- apiURL: jest.fn((x) => "../../web/" + x),
- };
+ return mockApi;
},
}));
}
diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js
new file mode 100644
index 00000000..450b4f5f
--- /dev/null
+++ b/web/extensions/core/groupNode.js
@@ -0,0 +1,1054 @@
+import { app } from "../../scripts/app.js";
+import { api } from "../../scripts/api.js";
+import { getWidgetType } from "../../scripts/widgets.js";
+import { mergeIfValid } from "./widgetInputs.js";
+
+const GROUP = Symbol();
+
+const Workflow = {
+ InUse: {
+ Free: 0,
+ Registered: 1,
+ InWorkflow: 2,
+ },
+ isInUseGroupNode(name) {
+ const id = `workflow/${name}`;
+ // Check if lready registered/in use in this workflow
+ if (app.graph.extra?.groupNodes?.[name]) {
+ if (app.graph._nodes.find((n) => n.type === id)) {
+ return Workflow.InUse.InWorkflow;
+ } else {
+ return Workflow.InUse.Registered;
+ }
+ }
+ return Workflow.InUse.Free;
+ },
+ storeGroupNode(name, data) {
+ let extra = app.graph.extra;
+ if (!extra) app.graph.extra = extra = {};
+ let groupNodes = extra.groupNodes;
+ if (!groupNodes) extra.groupNodes = groupNodes = {};
+ groupNodes[name] = data;
+ },
+};
+
+class GroupNodeBuilder {
+ constructor(nodes) {
+ this.nodes = nodes;
+ }
+
+ build() {
+ const name = this.getName();
+ if (!name) return;
+
+ // Sort the nodes so they are in execution order
+ // this allows for widgets to be in the correct order when reconstructing
+ this.sortNodes();
+
+ this.nodeData = this.getNodeData();
+ Workflow.storeGroupNode(name, this.nodeData);
+
+ return { name, nodeData: this.nodeData };
+ }
+
+ getName() {
+ const name = prompt("Enter group name");
+ if (!name) return;
+ const used = Workflow.isInUseGroupNode(name);
+ switch (used) {
+ case Workflow.InUse.InWorkflow:
+ alert(
+ "An in use group node with this name already exists embedded in this workflow, please remove any instances or use a new name."
+ );
+ return;
+ case Workflow.InUse.Registered:
+ if (
+ !confirm(
+ "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?"
+ )
+ ) {
+ return;
+ }
+ break;
+ }
+ return name;
+ }
+
+ sortNodes() {
+ // Gets the builders nodes in graph execution order
+ const nodesInOrder = app.graph.computeExecutionOrder(false);
+ this.nodes = this.nodes
+ .map((node) => ({ index: nodesInOrder.indexOf(node), node }))
+ .sort((a, b) => a.index - b.index || a.node.id - b.node.id)
+ .map(({ node }) => node);
+ }
+
+ getNodeData() {
+ const storeLinkTypes = (config) => {
+ // Store link types for dynamically typed nodes e.g. reroutes
+ for (const link of config.links) {
+ const origin = app.graph.getNodeById(link[4]);
+ const type = origin.outputs[link[1]].type;
+ link.push(type);
+ }
+ };
+
+ const storeExternalLinks = (config) => {
+ // Store any external links to the group in the config so when rebuilding we add extra slots
+ config.external = [];
+ for (let i = 0; i < this.nodes.length; i++) {
+ const node = this.nodes[i];
+ if (!node.outputs?.length) continue;
+ for (let slot = 0; slot < node.outputs.length; slot++) {
+ let hasExternal = false;
+ const output = node.outputs[slot];
+ let type = output.type;
+ if (!output.links?.length) continue;
+ for (const l of output.links) {
+ const link = app.graph.links[l];
+ if (!link) continue;
+ if (type === "*") type = link.type;
+
+ if (!app.canvas.selected_nodes[link.target_id]) {
+ hasExternal = true;
+ break;
+ }
+ }
+ if (hasExternal) {
+ config.external.push([i, slot, type]);
+ }
+ }
+ }
+ };
+
+ // Use the built in copyToClipboard function to generate the node data we need
+ const backup = localStorage.getItem("litegrapheditor_clipboard");
+ try {
+ app.canvas.copyToClipboard(this.nodes);
+ const config = JSON.parse(localStorage.getItem("litegrapheditor_clipboard"));
+
+ storeLinkTypes(config);
+ storeExternalLinks(config);
+
+ return config;
+ } finally {
+ localStorage.setItem("litegrapheditor_clipboard", backup);
+ }
+ }
+}
+
+export class GroupNodeConfig {
+ constructor(name, nodeData) {
+ this.name = name;
+ this.nodeData = nodeData;
+ this.getLinks();
+
+ this.inputCount = 0;
+ this.oldToNewOutputMap = {};
+ this.newToOldOutputMap = {};
+ this.oldToNewInputMap = {};
+ this.oldToNewWidgetMap = {};
+ this.newToOldWidgetMap = {};
+ this.primitiveDefs = {};
+ this.widgetToPrimitive = {};
+ this.primitiveToWidget = {};
+ }
+
+ async registerType(source = "workflow") {
+ this.nodeDef = {
+ output: [],
+ output_name: [],
+ output_is_list: [],
+ name: source + "/" + this.name,
+ display_name: this.name,
+ category: "group nodes" + ("/" + source),
+ input: { required: {} },
+
+ [GROUP]: this,
+ };
+
+ this.inputs = [];
+ const seenInputs = {};
+ const seenOutputs = {};
+ for (let i = 0; i < this.nodeData.nodes.length; i++) {
+ const node = this.nodeData.nodes[i];
+ node.index = i;
+ this.processNode(node, seenInputs, seenOutputs);
+ }
+ await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
+ }
+
+ getLinks() {
+ this.linksFrom = {};
+ this.linksTo = {};
+ this.externalFrom = {};
+
+ // Extract links for easy lookup
+ for (const l of this.nodeData.links) {
+ const [sourceNodeId, sourceNodeSlot, targetNodeId, targetNodeSlot] = l;
+
+ // Skip links outside the copy config
+ if (sourceNodeId == null) continue;
+
+ if (!this.linksFrom[sourceNodeId]) {
+ this.linksFrom[sourceNodeId] = {};
+ }
+ this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
+
+ if (!this.linksTo[targetNodeId]) {
+ this.linksTo[targetNodeId] = {};
+ }
+ this.linksTo[targetNodeId][targetNodeSlot] = l;
+ }
+
+ if (this.nodeData.external) {
+ for (const ext of this.nodeData.external) {
+ if (!this.externalFrom[ext[0]]) {
+ this.externalFrom[ext[0]] = { [ext[1]]: ext[2] };
+ } else {
+ this.externalFrom[ext[0]][ext[1]] = ext[2];
+ }
+ }
+ }
+ }
+
+ processNode(node, seenInputs, seenOutputs) {
+ const def = this.getNodeDef(node);
+ if (!def) return;
+
+ const inputs = { ...def.input?.required, ...def.input?.optional };
+
+ this.inputs.push(this.processNodeInputs(node, seenInputs, inputs));
+ if (def.output?.length) this.processNodeOutputs(node, seenOutputs, def);
+ }
+
+ getNodeDef(node) {
+ const def = globalDefs[node.type];
+ if (def) return def;
+
+ const linksFrom = this.linksFrom[node.index];
+ if (node.type === "PrimitiveNode") {
+ // Skip as its not linked
+ if (!linksFrom) return;
+
+ let type = linksFrom["0"][5];
+ if (type === "COMBO") {
+ // Use the array items
+ const source = node.outputs[0].widget.name;
+ const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
+ const fromType = globalDefs[fromTypeName];
+ const input = fromType.input.required[source] ?? fromType.input.optional[source];
+ type = input[0];
+ }
+
+ const def = (this.primitiveDefs[node.index] = {
+ input: {
+ required: {
+ value: [type, {}],
+ },
+ },
+ output: [type],
+ output_name: [],
+ output_is_list: [],
+ });
+ return def;
+ } else if (node.type === "Reroute") {
+ const linksTo = this.linksTo[node.index];
+ if (linksTo && linksFrom && !this.externalFrom[node.index]?.[0]) {
+ // Being used internally
+ return null;
+ }
+
+ let rerouteType = "*";
+ if (linksFrom) {
+ const [, , id, slot] = linksFrom["0"];
+ rerouteType = this.nodeData.nodes[id].inputs[slot].type;
+ } else if (linksTo) {
+ const [id, slot] = linksTo["0"];
+ rerouteType = this.nodeData.nodes[id].outputs[slot].type;
+ } else {
+ // Reroute used as a pipe
+ for (const l of this.nodeData.links) {
+ if (l[2] === node.index) {
+ rerouteType = l[5];
+ break;
+ }
+ }
+ if (rerouteType === "*") {
+ // Check for an external link
+ const t = this.externalFrom[node.index]?.[0];
+ if (t) {
+ rerouteType = t;
+ }
+ }
+ }
+
+ return {
+ input: {
+ required: {
+ [rerouteType]: [rerouteType, {}],
+ },
+ },
+ output: [rerouteType],
+ output_name: [],
+ output_is_list: [],
+ };
+ }
+
+ console.warn("Skipping virtual node " + node.type + " when building group node " + this.name);
+ }
+
+ getInputConfig(node, inputName, seenInputs, config, extra) {
+ let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
+ let prefix = "";
+ // Special handling for primitive to include the title if it is set rather than just "value"
+ if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
+ prefix = `${node.title ?? node.type} `;
+ name = `${prefix}${inputName}`;
+ if (name in seenInputs) {
+ name = `${prefix}${seenInputs[name]} ${inputName}`;
+ }
+ }
+ seenInputs[name] = (seenInputs[name] ?? 1) + 1;
+
+ if (inputName === "seed" || inputName === "noise_seed") {
+ if (!extra) extra = {};
+ extra.control_after_generate = `${prefix}control_after_generate`;
+ }
+ if (config[0] === "IMAGEUPLOAD") {
+ if (!extra) extra = {};
+ extra.widget = `${prefix}${config[1]?.widget ?? "image"}`;
+ }
+
+ if (extra) {
+ config = [config[0], { ...config[1], ...extra }];
+ }
+
+ return { name, config };
+ }
+
+ processWidgetInputs(inputs, node, inputNames, seenInputs) {
+ const slots = [];
+ const converted = new Map();
+ const widgetMap = (this.oldToNewWidgetMap[node.index] = {});
+ for (const inputName of inputNames) {
+ let widgetType = getWidgetType(inputs[inputName], inputName);
+ if (widgetType) {
+ const convertedIndex = node.inputs?.findIndex(
+ (inp) => inp.name === inputName && inp.widget?.name === inputName
+ );
+ if (convertedIndex > -1) {
+ // This widget has been converted to a widget
+ // We need to store this in the correct position so link ids line up
+ converted.set(convertedIndex, inputName);
+ widgetMap[inputName] = null;
+ } else {
+ // Normal widget
+ const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
+ this.nodeDef.input.required[name] = config;
+ widgetMap[inputName] = name;
+ this.newToOldWidgetMap[name] = { node, inputName };
+ }
+ } else {
+ // Normal input
+ slots.push(inputName);
+ }
+ }
+ return { converted, slots };
+ }
+
+ checkPrimitiveConnection(link, inputName, inputs) {
+ const sourceNode = this.nodeData.nodes[link[0]];
+ if (sourceNode.type === "PrimitiveNode") {
+ // Merge link configurations
+ const [sourceNodeId, _, targetNodeId, __] = link;
+ const primitiveDef = this.primitiveDefs[sourceNodeId];
+ const targetWidget = inputs[inputName];
+ const primitiveConfig = primitiveDef.input.required.value;
+ const output = { widget: primitiveConfig };
+ const config = mergeIfValid(output, targetWidget, false, null, primitiveConfig);
+ primitiveConfig[1] = config?.customConfig ?? inputs[inputName][1] ? { ...inputs[inputName][1] } : {};
+
+ let name = this.oldToNewWidgetMap[sourceNodeId]["value"];
+ name = name.substr(0, name.length - 6);
+ primitiveConfig[1].control_after_generate = true;
+ primitiveConfig[1].control_prefix = name;
+
+ let toPrimitive = this.widgetToPrimitive[targetNodeId];
+ if (!toPrimitive) {
+ toPrimitive = this.widgetToPrimitive[targetNodeId] = {};
+ }
+ if (toPrimitive[inputName]) {
+ toPrimitive[inputName].push(sourceNodeId);
+ }
+ toPrimitive[inputName] = sourceNodeId;
+
+ let toWidget = this.primitiveToWidget[sourceNodeId];
+ if (!toWidget) {
+ toWidget = this.primitiveToWidget[sourceNodeId] = [];
+ }
+ toWidget.push({ nodeId: targetNodeId, inputName });
+ }
+ }
+
+ processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) {
+ for (let i = 0; i < slots.length; i++) {
+ const inputName = slots[i];
+ if (linksTo[i]) {
+ this.checkPrimitiveConnection(linksTo[i], inputName, inputs);
+ // This input is linked so we can skip it
+ continue;
+ }
+
+ const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]);
+ this.nodeDef.input.required[name] = config;
+ inputMap[i] = this.inputCount++;
+ }
+ }
+
+ processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) {
+ // Add converted widgets sorted into their index order (ordered as they were converted) so link ids match up
+ const convertedSlots = [...converted.keys()].sort().map((k) => converted.get(k));
+ for (let i = 0; i < convertedSlots.length; i++) {
+ const inputName = convertedSlots[i];
+ if (linksTo[slots.length + i]) {
+ this.checkPrimitiveConnection(linksTo[slots.length + i], inputName, inputs);
+ // This input is linked so we can skip it
+ continue;
+ }
+
+ const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], {
+ defaultInput: true,
+ });
+ this.nodeDef.input.required[name] = config;
+ inputMap[slots.length + i] = this.inputCount++;
+ }
+ }
+
+ processNodeInputs(node, seenInputs, inputs) {
+ const inputMapping = [];
+
+ const inputNames = Object.keys(inputs);
+ if (!inputNames.length) return;
+
+ const { converted, slots } = this.processWidgetInputs(inputs, node, inputNames, seenInputs);
+ const linksTo = this.linksTo[node.index] ?? {};
+ const inputMap = (this.oldToNewInputMap[node.index] = {});
+ this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
+ this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
+
+ return inputMapping;
+ }
+
+ processNodeOutputs(node, seenOutputs, def) {
+ const oldToNew = (this.oldToNewOutputMap[node.index] = {});
+
+ // Add outputs
+ for (let outputId = 0; outputId < def.output.length; outputId++) {
+ const linksFrom = this.linksFrom[node.index];
+ if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) {
+ // This output is linked internally so we can skip it
+ continue;
+ }
+
+ oldToNew[outputId] = this.nodeDef.output.length;
+ this.newToOldOutputMap[this.nodeDef.output.length] = { node, slot: outputId };
+ this.nodeDef.output.push(def.output[outputId]);
+ this.nodeDef.output_is_list.push(def.output_is_list[outputId]);
+
+ let label = def.output_name?.[outputId] ?? def.output[outputId];
+ const output = node.outputs.find((o) => o.name === label);
+ if (output?.label) {
+ label = output.label;
+ }
+ let name = label;
+ if (name in seenOutputs) {
+ const prefix = `${node.title ?? node.type} `;
+ name = `${prefix}${label}`;
+ if (name in seenOutputs) {
+ name = `${prefix}${node.index} ${label}`;
+ }
+ }
+ seenOutputs[name] = 1;
+
+ this.nodeDef.output_name.push(name);
+ }
+ }
+
+ static async registerFromWorkflow(groupNodes, missingNodeTypes) {
+ for (const g in groupNodes) {
+ const groupData = groupNodes[g];
+
+ let hasMissing = false;
+ for (const n of groupData.nodes) {
+ // Find missing node types
+ if (!(n.type in LiteGraph.registered_node_types)) {
+ missingNodeTypes.push(n.type);
+ hasMissing = true;
+ }
+ }
+
+ if (hasMissing) continue;
+
+ const config = new GroupNodeConfig(g, groupData);
+ await config.registerType();
+ }
+ }
+}
+
+export class GroupNodeHandler {
+ node;
+ groupData;
+
+ constructor(node) {
+ this.node = node;
+ this.groupData = node.constructor?.nodeData?.[GROUP];
+
+ this.node.setInnerNodes = (innerNodes) => {
+ this.innerNodes = innerNodes;
+
+ for (let innerNodeIndex = 0; innerNodeIndex < this.innerNodes.length; innerNodeIndex++) {
+ const innerNode = this.innerNodes[innerNodeIndex];
+
+ for (const w of innerNode.widgets ?? []) {
+ if (w.type === "converted-widget") {
+ w.serializeValue = w.origSerializeValue;
+ }
+ }
+
+ innerNode.index = innerNodeIndex;
+ innerNode.getInputNode = (slot) => {
+ // Check if this input is internal or external
+ const externalSlot = this.groupData.oldToNewInputMap[innerNode.index]?.[slot];
+ if (externalSlot != null) {
+ return this.node.getInputNode(externalSlot);
+ }
+
+ // Internal link
+ const innerLink = this.groupData.linksTo[innerNode.index]?.[slot];
+ if (!innerLink) return null;
+
+ const inputNode = innerNodes[innerLink[0]];
+ // Primitives will already apply their values
+ if (inputNode.type === "PrimitiveNode") return null;
+
+ return inputNode;
+ };
+
+ innerNode.getInputLink = (slot) => {
+ const externalSlot = this.groupData.oldToNewInputMap[innerNode.index]?.[slot];
+ if (externalSlot != null) {
+ // The inner node is connected via the group node inputs
+ const linkId = this.node.inputs[externalSlot].link;
+ let link = app.graph.links[linkId];
+
+ // Use the outer link, but update the target to the inner node
+ link = {
+ ...link,
+ target_id: innerNode.id,
+ target_slot: +slot,
+ };
+ return link;
+ }
+
+ let link = this.groupData.linksTo[innerNode.index]?.[slot];
+ if (!link) return null;
+ // Use the inner link, but update the origin node to be inner node id
+ link = {
+ origin_id: innerNodes[link[0]].id,
+ origin_slot: link[1],
+ target_id: innerNode.id,
+ target_slot: +slot,
+ };
+ return link;
+ };
+ }
+ };
+
+ this.node.updateLink = (link) => {
+ // Replace the group node reference with the internal node
+ link = { ...link };
+ const output = this.groupData.newToOldOutputMap[link.origin_slot];
+ let innerNode = this.innerNodes[output.node.index];
+ let l;
+ while (innerNode.type === "Reroute") {
+ l = innerNode.getInputLink(0);
+ innerNode = innerNode.getInputNode(0);
+ }
+
+ link.origin_id = innerNode.id;
+ link.origin_slot = l?.origin_slot ?? output.slot;
+ return link;
+ };
+
+ this.node.getInnerNodes = () => {
+ if (!this.innerNodes) {
+ this.node.setInnerNodes(
+ this.groupData.nodeData.nodes.map((n, i) => {
+ const innerNode = LiteGraph.createNode(n.type);
+ innerNode.configure(n);
+ innerNode.id = `${this.node.id}:${i}`;
+ return innerNode;
+ })
+ );
+ }
+
+ this.updateInnerWidgets();
+
+ return this.innerNodes;
+ };
+
+ this.node.convertToNodes = () => {
+ const addInnerNodes = () => {
+ const backup = localStorage.getItem("litegrapheditor_clipboard");
+ // Clone the node data so we dont mutate it for other nodes
+ const c = { ...this.groupData.nodeData };
+ c.nodes = [...c.nodes];
+ const innerNodes = this.node.getInnerNodes();
+ let ids = [];
+ for (let i = 0; i < c.nodes.length; i++) {
+ let id = innerNodes?.[i]?.id;
+ // Use existing IDs if they are set on the inner nodes
+ if (id == null || isNaN(id)) {
+ id = undefined;
+ } else {
+ ids.push(id);
+ }
+ c.nodes[i] = { ...c.nodes[i], id };
+ }
+ localStorage.setItem("litegrapheditor_clipboard", JSON.stringify(c));
+ app.canvas.pasteFromClipboard();
+ localStorage.setItem("litegrapheditor_clipboard", backup);
+
+ const [x, y] = this.node.pos;
+ let top;
+ let left;
+ // Configure nodes with current widget data
+ const selectedIds = ids.length ? ids : Object.keys(app.canvas.selected_nodes);
+ const newNodes = [];
+ for (let i = 0; i < selectedIds.length; i++) {
+ const id = selectedIds[i];
+ const newNode = app.graph.getNodeById(id);
+ const innerNode = innerNodes[i];
+ newNodes.push(newNode);
+
+ if (left == null || newNode.pos[0] < left) {
+ left = newNode.pos[0];
+ }
+ if (top == null || newNode.pos[1] < top) {
+ top = newNode.pos[1];
+ }
+
+ const map = this.groupData.oldToNewWidgetMap[innerNode.index];
+ if (map) {
+ const widgets = Object.keys(map);
+
+ for (const oldName of widgets) {
+ const newName = map[oldName];
+ if (!newName) continue;
+
+ const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
+ if (widgetIndex === -1) continue;
+
+ // Populate the main and any linked widgets
+ if (innerNode.type === "PrimitiveNode") {
+ for (let i = 0; i < newNode.widgets.length; i++) {
+ newNode.widgets[i].value = this.node.widgets[widgetIndex + i].value;
+ }
+ } else {
+ const outerWidget = this.node.widgets[widgetIndex];
+ const newWidget = newNode.widgets.find((w) => w.name === oldName);
+ if (!newWidget) continue;
+
+ newWidget.value = outerWidget.value;
+ for (let w = 0; w < outerWidget.linkedWidgets?.length; w++) {
+ newWidget.linkedWidgets[w].value = outerWidget.linkedWidgets[w].value;
+ }
+ }
+ }
+ }
+ }
+
+ // Shift each node
+ for (const newNode of newNodes) {
+ newNode.pos = [newNode.pos[0] - (left - x), newNode.pos[1] - (top - y)];
+ }
+
+ return { newNodes, selectedIds };
+ };
+
+ const reconnectInputs = (selectedIds) => {
+ for (const innerNodeIndex in this.groupData.oldToNewInputMap) {
+ const id = selectedIds[innerNodeIndex];
+ const newNode = app.graph.getNodeById(id);
+ const map = this.groupData.oldToNewInputMap[innerNodeIndex];
+ for (const innerInputId in map) {
+ const groupSlotId = map[innerInputId];
+ if (groupSlotId == null) continue;
+ const slot = node.inputs[groupSlotId];
+ if (slot.link == null) continue;
+ const link = app.graph.links[slot.link];
+ // connect this node output to the input of another node
+ const originNode = app.graph.getNodeById(link.origin_id);
+ originNode.connect(link.origin_slot, newNode, +innerInputId);
+ }
+ }
+ };
+
+ const reconnectOutputs = () => {
+ for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
+ const output = node.outputs[groupOutputId];
+ if (!output.links) continue;
+ const links = [...output.links];
+ for (const l of links) {
+ const slot = this.groupData.newToOldOutputMap[groupOutputId];
+ const link = app.graph.links[l];
+ const targetNode = app.graph.getNodeById(link.target_id);
+ const newNode = app.graph.getNodeById(selectedIds[slot.node.index]);
+ newNode.connect(slot.slot, targetNode, link.target_slot);
+ }
+ }
+ };
+
+ const { newNodes, selectedIds } = addInnerNodes();
+ reconnectInputs(selectedIds);
+ reconnectOutputs(selectedIds);
+ app.graph.remove(this.node);
+
+ return newNodes;
+ };
+
+ const getExtraMenuOptions = this.node.getExtraMenuOptions;
+ this.node.getExtraMenuOptions = function (_, options) {
+ getExtraMenuOptions?.apply(this, arguments);
+
+ let optionIndex = options.findIndex((o) => o.content === "Outputs");
+ if (optionIndex === -1) optionIndex = options.length;
+ else optionIndex++;
+ options.splice(optionIndex, 0, null, {
+ content: "Convert to nodes",
+ callback: () => {
+ return this.convertToNodes();
+ },
+ });
+ };
+
+ // Draw custom collapse icon to identity this as a group
+ const onDrawTitleBox = this.node.onDrawTitleBox;
+ this.node.onDrawTitleBox = function (ctx, height, size, scale) {
+ onDrawTitleBox?.apply(this, arguments);
+
+ const fill = ctx.fillStyle;
+ ctx.beginPath();
+ ctx.rect(11, -height + 11, 2, 2);
+ ctx.rect(14, -height + 11, 2, 2);
+ ctx.rect(17, -height + 11, 2, 2);
+ ctx.rect(11, -height + 14, 2, 2);
+ ctx.rect(14, -height + 14, 2, 2);
+ ctx.rect(17, -height + 14, 2, 2);
+ ctx.rect(11, -height + 17, 2, 2);
+ ctx.rect(14, -height + 17, 2, 2);
+ ctx.rect(17, -height + 17, 2, 2);
+
+ ctx.fillStyle = this.boxcolor || LiteGraph.NODE_DEFAULT_BOXCOLOR;
+ ctx.fill();
+ ctx.fillStyle = fill;
+ };
+
+ // Draw progress label
+ const onDrawForeground = node.onDrawForeground;
+ const groupData = this.groupData.nodeData;
+ node.onDrawForeground = function (ctx) {
+ const r = onDrawForeground?.apply?.(this, arguments);
+ if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) {
+ const n = groupData.nodes[this.runningInternalNodeId];
+ const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`;
+ ctx.save();
+ ctx.font = "12px sans-serif";
+ const sz = ctx.measureText(message);
+ ctx.fillStyle = node.boxcolor || LiteGraph.NODE_DEFAULT_BOXCOLOR;
+ ctx.beginPath();
+ ctx.roundRect(0, -LiteGraph.NODE_TITLE_HEIGHT - 20, sz.width + 12, 20, 5);
+ ctx.fill();
+
+ ctx.fillStyle = "#fff";
+ ctx.fillText(message, 6, -LiteGraph.NODE_TITLE_HEIGHT - 6);
+ ctx.restore();
+ }
+ };
+
+ // Flag this node as needing to be reset
+ const onExecutionStart = this.node.onExecutionStart;
+ this.node.onExecutionStart = function () {
+ this.resetExecution = true;
+ return onExecutionStart?.apply(this, arguments);
+ };
+
+ function handleEvent(type, getId, getEvent) {
+ const handler = ({ detail }) => {
+ const id = getId(detail);
+ if (!id) return;
+ const node = app.graph.getNodeById(id);
+ if (node) return;
+
+ const innerNodeIndex = this.innerNodes?.findIndex((n) => n.id == id);
+ if (innerNodeIndex > -1) {
+ this.node.runningInternalNodeId = innerNodeIndex;
+ api.dispatchEvent(new CustomEvent(type, { detail: getEvent(detail, this.node.id + "", this.node) }));
+ }
+ };
+ api.addEventListener(type, handler);
+ return handler;
+ }
+
+ const executing = handleEvent.call(
+ this,
+ "executing",
+ (d) => d,
+ (d, id, node) => id
+ );
+
+ const executed = handleEvent.call(
+ this,
+ "executed",
+ (d) => d?.node,
+ (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution })
+ );
+
+ const onRemoved = node.onRemoved;
+ this.node.onRemoved = function () {
+ onRemoved?.apply(this, arguments);
+ api.removeEventListener("executing", executing);
+ api.removeEventListener("executed", executed);
+ };
+ }
+
+ updateInnerWidgets() {
+ for (const newWidgetName in this.groupData.newToOldWidgetMap) {
+ const newWidget = this.node.widgets.find((w) => w.name === newWidgetName);
+ if (!newWidget) continue;
+
+ const newValue = newWidget.value;
+ const old = this.groupData.newToOldWidgetMap[newWidgetName];
+ let innerNode = this.innerNodes[old.node.index];
+
+ if (innerNode.type === "PrimitiveNode") {
+ innerNode.primitiveValue = newValue;
+ const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
+ for (const linked of primitiveLinked) {
+ const node = this.innerNodes[linked.nodeId];
+ const widget = node.widgets.find((w) => w.name === linked.inputName);
+
+ if (widget) {
+ widget.value = newValue;
+ }
+ }
+ continue;
+ }
+
+ const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
+ if (widget) {
+ widget.value = newValue;
+ }
+ }
+ }
+
+ populatePrimitive(node, nodeId, oldName, i, linkedShift) {
+ // Converted widget, populate primitive if linked
+ const primitiveId = this.groupData.widgetToPrimitive[nodeId]?.[oldName];
+ if (primitiveId == null) return;
+ const targetWidgetName = this.groupData.oldToNewWidgetMap[primitiveId]["value"];
+ const targetWidgetIndex = this.node.widgets.findIndex((w) => w.name === targetWidgetName);
+ if (targetWidgetIndex > -1) {
+ const primitiveNode = this.innerNodes[primitiveId];
+ let len = primitiveNode.widgets.length;
+ if (len - 1 !== this.node.widgets[targetWidgetIndex].linkedWidgets?.length) {
+ // Fallback handling for if some reason the primitive has a different number of widgets
+ // we dont want to overwrite random widgets, better to leave blank
+ len = 1;
+ }
+ for (let i = 0; i < len; i++) {
+ this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
+ }
+ }
+ }
+
+ populateWidgets() {
+ for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
+ const node = this.groupData.nodeData.nodes[nodeId];
+
+ if (!node.widgets_values?.length) continue;
+
+ const map = this.groupData.oldToNewWidgetMap[nodeId];
+ const widgets = Object.keys(map);
+
+ let linkedShift = 0;
+ for (let i = 0; i < widgets.length; i++) {
+ const oldName = widgets[i];
+ const newName = map[oldName];
+ const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
+ const mainWidget = this.node.widgets[widgetIndex];
+ if (!newName) {
+ // New name will be null if its a converted widget
+ this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
+
+ // Find the inner widget and shift by the number of linked widgets as they will have been removed too
+ const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
+ linkedShift += innerWidget.linkedWidgets?.length ?? 0;
+ continue;
+ }
+
+ if (widgetIndex === -1) {
+ continue;
+ }
+
+ // Populate the main and any linked widget
+ mainWidget.value = node.widgets_values[i + linkedShift];
+ for (let w = 0; w < mainWidget.linkedWidgets?.length; w++) {
+ this.node.widgets[widgetIndex + w + 1].value = node.widgets_values[i + ++linkedShift];
+ }
+ }
+ }
+ }
+
+ replaceNodes(nodes) {
+ let top;
+ let left;
+
+ for (let i = 0; i < nodes.length; i++) {
+ const node = nodes[i];
+ if (left == null || node.pos[0] < left) {
+ left = node.pos[0];
+ }
+ if (top == null || node.pos[1] < top) {
+ top = node.pos[1];
+ }
+
+ this.linkOutputs(node, i);
+ app.graph.remove(node);
+ }
+
+ this.linkInputs();
+ this.node.pos = [left, top];
+ }
+
+ linkOutputs(originalNode, nodeId) {
+ if (!originalNode.outputs) return;
+
+ for (const output of originalNode.outputs) {
+ if (!output.links) continue;
+ // Clone the links as they'll be changed if we reconnect
+ const links = [...output.links];
+ for (const l of links) {
+ const link = app.graph.links[l];
+ if (!link) continue;
+
+ const targetNode = app.graph.getNodeById(link.target_id);
+ const newSlot = this.groupData.oldToNewOutputMap[nodeId]?.[link.origin_slot];
+ if (newSlot != null) {
+ this.node.connect(newSlot, targetNode, link.target_slot);
+ }
+ }
+ }
+ }
+
+ linkInputs() {
+ for (const link of this.groupData.nodeData.links ?? []) {
+ const [, originSlot, targetId, targetSlot, actualOriginId] = link;
+ const originNode = app.graph.getNodeById(actualOriginId);
+ if (!originNode) continue; // this node is in the group
+ originNode.connect(originSlot, this.node.id, this.groupData.oldToNewInputMap[targetId][targetSlot]);
+ }
+ }
+
+ static getGroupData(node) {
+ return node.constructor?.nodeData?.[GROUP];
+ }
+
+ static isGroupNode(node) {
+ return !!node.constructor?.nodeData?.[GROUP];
+ }
+
+ static async fromNodes(nodes) {
+ // Process the nodes into the stored workflow group node data
+ const builder = new GroupNodeBuilder(nodes);
+ const res = builder.build();
+ if (!res) return;
+
+ const { name, nodeData } = res;
+
+ // Convert this data into a LG node definition and register it
+ const config = new GroupNodeConfig(name, nodeData);
+ await config.registerType();
+
+ const groupNode = LiteGraph.createNode(`workflow/${name}`);
+ // Reuse the existing nodes for this instance
+ groupNode.setInnerNodes(builder.nodes);
+ groupNode[GROUP].populateWidgets();
+ app.graph.add(groupNode);
+
+ // Remove all converted nodes and relink them
+ groupNode[GROUP].replaceNodes(builder.nodes);
+ return groupNode;
+ }
+}
+
+function addConvertToGroupOptions() {
+ function addOption(options, index) {
+ const selected = Object.values(app.canvas.selected_nodes ?? {});
+ const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n));
+ options.splice(index + 1, null, {
+ content: `Convert to Group Node`,
+ disabled,
+ callback: async () => {
+ return await GroupNodeHandler.fromNodes(selected);
+ },
+ });
+ }
+
+ // Add to canvas
+ const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions;
+ LGraphCanvas.prototype.getCanvasMenuOptions = function () {
+ const options = getCanvasMenuOptions.apply(this, arguments);
+ const index = options.findIndex((o) => o?.content === "Add Group") + 1 || opts.length;
+ addOption(options, index);
+ return options;
+ };
+
+ // Add to nodes
+ const getNodeMenuOptions = LGraphCanvas.prototype.getNodeMenuOptions;
+ LGraphCanvas.prototype.getNodeMenuOptions = function (node) {
+ const options = getNodeMenuOptions.apply(this, arguments);
+ if (!GroupNodeHandler.isGroupNode(node)) {
+ const index = options.findIndex((o) => o?.content === "Outputs") + 1 || opts.length - 1;
+ addOption(options, index);
+ }
+ return options;
+ };
+}
+
+const id = "Comfy.GroupNode";
+let globalDefs;
+const ext = {
+ name: id,
+ setup() {
+ addConvertToGroupOptions();
+ },
+ async beforeConfigureGraph(graphData, missingNodeTypes) {
+ const nodes = graphData?.extra?.groupNodes;
+ if (nodes) {
+ await GroupNodeConfig.registerFromWorkflow(nodes, missingNodeTypes);
+ }
+ },
+ addCustomNodeDefs(defs) {
+ // Store this so we can mutate it later with group nodes
+ globalDefs = defs;
+ },
+ nodeCreated(node) {
+ if (GroupNodeHandler.isGroupNode(node)) {
+ node[GROUP] = new GroupNodeHandler(node);
+ }
+ },
+};
+
+app.registerExtension(ext);
diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js
index b6479f45..2d482174 100644
--- a/web/extensions/core/nodeTemplates.js
+++ b/web/extensions/core/nodeTemplates.js
@@ -1,5 +1,6 @@
import { app } from "../../scripts/app.js";
import { ComfyDialog, $el } from "../../scripts/ui.js";
+import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Adds the ability to save and add multiple nodes as a template
// To save:
@@ -34,7 +35,7 @@ class ManageTemplates extends ComfyDialog {
type: "file",
accept: ".json",
multiple: true,
- style: {display: "none"},
+ style: { display: "none" },
parent: document.body,
onchange: () => this.importAll(),
});
@@ -109,13 +110,13 @@ class ManageTemplates extends ComfyDialog {
return;
}
- const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
- const blob = new Blob([json], {type: "application/json"});
+ const json = JSON.stringify({ templates: this.templates }, null, 2); // convert the data to a JSON string
+ const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "node_templates.json",
- style: {display: "none"},
+ style: { display: "none" },
parent: document.body,
});
a.click();
@@ -291,11 +292,11 @@ app.registerExtension({
setup() {
const manage = new ManageTemplates();
- const clipboardAction = (cb) => {
+ const clipboardAction = async (cb) => {
// We use the clipboard functions but dont want to overwrite the current user clipboard
// Restore it after we've run our callback
const old = localStorage.getItem("litegrapheditor_clipboard");
- cb();
+ await cb();
localStorage.setItem("litegrapheditor_clipboard", old);
};
@@ -309,13 +310,31 @@ app.registerExtension({
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
const name = prompt("Enter name");
- if (!name || !name.trim()) return;
+ if (!name?.trim()) return;
clipboardAction(() => {
app.canvas.copyToClipboard();
+ let data = localStorage.getItem("litegrapheditor_clipboard");
+ data = JSON.parse(data);
+ const nodeIds = Object.keys(app.canvas.selected_nodes);
+ for (let i = 0; i < nodeIds.length; i++) {
+ const node = app.graph.getNodeById(nodeIds[i]);
+ const nodeData = node?.constructor.nodeData;
+
+ let groupData = GroupNodeHandler.getGroupData(node);
+ if (groupData) {
+ groupData = groupData.nodeData;
+ if (!data.groupNodes) {
+ data.groupNodes = {};
+ }
+ data.groupNodes[nodeData.name] = groupData;
+ data.nodes[i].type = nodeData.name;
+ }
+ }
+
manage.templates.push({
name,
- data: localStorage.getItem("litegrapheditor_clipboard"),
+ data: JSON.stringify(data),
});
manage.store();
});
@@ -323,15 +342,19 @@ app.registerExtension({
});
// Map each template to a menu item
- const subItems = manage.templates.map((t) => ({
- content: t.name,
- callback: () => {
- clipboardAction(() => {
- localStorage.setItem("litegrapheditor_clipboard", t.data);
- app.canvas.pasteFromClipboard();
- });
- },
- }));
+ const subItems = manage.templates.map((t) => {
+ return {
+ content: t.name,
+ callback: () => {
+ clipboardAction(async () => {
+ const data = JSON.parse(t.data);
+ await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
+ localStorage.setItem("litegrapheditor_clipboard", t.data);
+ app.canvas.pasteFromClipboard();
+ });
+ },
+ };
+ });
subItems.push(null, {
content: "Manage",
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index 5c8fbc9b..b6fa411f 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -121,6 +121,110 @@ function isValidCombo(combo, obj) {
return true;
}
+export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
+ if (!config1) {
+ config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
+ }
+
+ if (config1[0] instanceof Array) {
+ if (!isValidCombo(config1[0], config2[0])) return false;
+ } else if (config1[0] !== config2[0]) {
+ // Types dont match
+ console.log(`connection rejected: types dont match`, config1[0], config2[0]);
+ return false;
+ }
+
+ const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
+
+ let customConfig;
+ const getCustomConfig = () => {
+ if (!customConfig) {
+ if (typeof structuredClone === "undefined") {
+ customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
+ } else {
+ customConfig = structuredClone(config1[1] ?? {});
+ }
+ }
+ return customConfig;
+ };
+
+ const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
+ for (const k of keys.values()) {
+ if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
+ let v1 = config1[1][k];
+ let v2 = config2[1]?.[k];
+
+ if (v1 === v2 || (!v1 && !v2)) continue;
+
+ if (isNumber) {
+ if (k === "min") {
+ const theirMax = config2[1]?.["max"];
+ if (theirMax != null && v1 > theirMax) {
+ console.log("connection rejected: min > max", v1, theirMax);
+ return false;
+ }
+ getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
+ continue;
+ } else if (k === "max") {
+ const theirMin = config2[1]?.["min"];
+ if (theirMin != null && v1 < theirMin) {
+ console.log("connection rejected: max < min", v1, theirMin);
+ return false;
+ }
+ getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
+ continue;
+ } else if (k === "step") {
+ let step;
+ if (v1 == null) {
+ // No current step
+ step = v2;
+ } else if (v2 == null) {
+ // No new step
+ step = v1;
+ } else {
+ if (v1 < v2) {
+ // Ensure v1 is larger for the mod
+ const a = v2;
+ v2 = v1;
+ v1 = a;
+ }
+ if (v1 % v2) {
+ console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
+ return false;
+ }
+
+ step = v1;
+ }
+
+ getCustomConfig()[k] = step;
+ continue;
+ }
+ }
+
+ console.log(`connection rejected: config ${k} values dont match`, v1, v2);
+ return false;
+ }
+ }
+
+ if (customConfig || forceUpdate) {
+ if (customConfig) {
+ output.widget[CONFIG] = [config1[0], customConfig];
+ }
+
+ const widget = recreateWidget?.call(this);
+ // When deleting a node this can be null
+ if (widget) {
+ const min = widget.options.min;
+ const max = widget.options.max;
+ if (min != null && widget.value < min) widget.value = min;
+ if (max != null && widget.value > max) widget.value = max;
+ widget.callback(widget.value);
+ }
+ }
+
+ return { customConfig };
+}
+
app.registerExtension({
name: "Comfy.WidgetInputs",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@@ -308,7 +412,7 @@ app.registerExtension({
this.isVirtualNode = true;
}
- applyToGraph() {
+ applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
@@ -325,10 +429,9 @@ app.registerExtension({
return links;
}
- let links = get_links(this);
+ let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
// For each output link copy our value over the original widget value
- for (const l of links) {
- const linkInfo = app.graph.links[l];
+ for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name;
@@ -405,7 +508,12 @@ app.registerExtension({
}
if (this.outputs[slot].links?.length) {
- return this.#isValidConnection(input);
+ const valid = this.#isValidConnection(input);
+ if (valid) {
+ // On connect of additional outputs, copy our value to their widget
+ this.applyToGraph([{ target_id: target_node.id, target_slot }]);
+ }
+ return valid;
}
}
@@ -462,12 +570,12 @@ app.registerExtension({
}
}
- if (widget.type === "number" || widget.type === "combo") {
+ if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
- addValueControlWidgets(this, widget, control_value);
+ addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
@@ -507,6 +615,7 @@ app.registerExtension({
this.#removeWidgets();
this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
+ return this.widgets[0];
}
#mergeWidgetConfig() {
@@ -547,108 +656,8 @@ app.registerExtension({
#isValidConnection(input, forceUpdate) {
// Only allow connections where the configs match
const output = this.outputs[0];
- const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
const config2 = input.widget[GET_CONFIG]();
-
- if (config1[0] instanceof Array) {
- if (!isValidCombo(config1[0], config2[0])) return false;
- } else if (config1[0] !== config2[0]) {
- // Types dont match
- console.log(`connection rejected: types dont match`, config1[0], config2[0]);
- return false;
- }
-
- const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
-
- let customConfig;
- const getCustomConfig = () => {
- if (!customConfig) {
- if (typeof structuredClone === "undefined") {
- customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
- } else {
- customConfig = structuredClone(config1[1] ?? {});
- }
- }
- return customConfig;
- };
-
- const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
- for (const k of keys.values()) {
- if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
- let v1 = config1[1][k];
- let v2 = config2[1][k];
-
- if (v1 === v2 || (!v1 && !v2)) continue;
-
- if (isNumber) {
- if (k === "min") {
- const theirMax = config2[1]["max"];
- if (theirMax != null && v1 > theirMax) {
- console.log("connection rejected: min > max", v1, theirMax);
- return false;
- }
- getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
- continue;
- } else if (k === "max") {
- const theirMin = config2[1]["min"];
- if (theirMin != null && v1 < theirMin) {
- console.log("connection rejected: max < min", v1, theirMin);
- return false;
- }
- getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
- continue;
- } else if (k === "step") {
- let step;
- if (v1 == null) {
- // No current step
- step = v2;
- } else if (v2 == null) {
- // No new step
- step = v1;
- } else {
- if (v1 < v2) {
- // Ensure v1 is larger for the mod
- const a = v2;
- v2 = v1;
- v1 = a;
- }
- if (v1 % v2) {
- console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
- return false;
- }
-
- step = v1;
- }
-
- getCustomConfig()[k] = step;
- continue;
- }
- }
-
- console.log(`connection rejected: config ${k} values dont match`, v1, v2);
- return false;
- }
- }
-
- if (customConfig || forceUpdate) {
- if (customConfig) {
- output.widget[CONFIG] = [config1[0], customConfig];
- }
-
- this.#recreateWidget();
-
- const widget = this.widgets[0];
- // When deleting a node this can be null
- if (widget) {
- const min = widget.options.min;
- const max = widget.options.max;
- if (min != null && widget.value < min) widget.value = min;
- if (max != null && widget.value > max) widget.value = max;
- widget.callback(widget.value);
- }
- }
-
- return true;
+ return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
}
#removeWidgets() {
diff --git a/web/scripts/app.js b/web/scripts/app.js
index cd20c40f..e9cfb277 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1,5 +1,5 @@
import { ComfyLogging } from "./logging.js";
-import { ComfyWidgets } from "./widgets.js";
+import { ComfyWidgets, getWidgetType } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
@@ -779,7 +779,7 @@ export class ComfyApp {
* Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
*/
#addPasteHandler() {
- document.addEventListener("paste", (e) => {
+ document.addEventListener("paste", async (e) => {
// ctrl+shift+v is used to paste nodes with connections
// this is handled by litegraph
if(this.shiftDown) return;
@@ -827,7 +827,7 @@ export class ComfyApp {
}
if (workflow && workflow.version && workflow.nodes && workflow.extra) {
- this.loadGraphData(workflow);
+ await this.loadGraphData(workflow);
}
else {
if (e.target.type === "text" || e.target.type === "textarea") {
@@ -1177,7 +1177,19 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
- this.nodeOutputs[detail.node] = detail.output;
+ const output = this.nodeOutputs[detail.node];
+ if (detail.merge && output) {
+ for (const k in detail.output ?? {}) {
+ const v = output[k];
+ if (v instanceof Array) {
+ output[k] = v.concat(detail.output[k]);
+ } else {
+ output[k] = detail.output[k];
+ }
+ }
+ } else {
+ this.nodeOutputs[detail.node] = detail.output;
+ }
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
@@ -1292,6 +1304,7 @@ export class ComfyApp {
this.#addProcessMouseHandler();
this.#addProcessKeyHandler();
this.#addConfigureHandler();
+ this.#addApiUpdateHandlers();
this.graph = new LGraph();
@@ -1328,7 +1341,7 @@ export class ComfyApp {
const json = localStorage.getItem("workflow");
if (json) {
const workflow = JSON.parse(json);
- this.loadGraphData(workflow);
+ await this.loadGraphData(workflow);
restored = true;
}
} catch (err) {
@@ -1337,7 +1350,7 @@ export class ComfyApp {
// We failed to restore a workflow so load the default
if (!restored) {
- this.loadGraphData();
+ await this.loadGraphData();
}
// Save current workflow automatically
@@ -1345,7 +1358,6 @@ export class ComfyApp {
this.#addDrawNodeHandler();
this.#addDrawGroupsHandler();
- this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addCopyHandler();
this.#addPasteHandler();
@@ -1365,11 +1377,81 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("registerCustomNodes");
}
+ async registerNodeDef(nodeId, nodeData) {
+ const self = this;
+ const node = Object.assign(
+ function ComfyNode() {
+ var inputs = nodeData["input"]["required"];
+ if (nodeData["input"]["optional"] != undefined) {
+ inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]);
+ }
+ const config = { minWidth: 1, minHeight: 1 };
+ for (const inputName in inputs) {
+ const inputData = inputs[inputName];
+ const type = inputData[0];
+
+ let widgetCreated = true;
+ const widgetType = getWidgetType(inputData, inputName);
+ if(widgetType) {
+ if(widgetType === "COMBO") {
+ Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
+ } else {
+ Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {});
+ }
+ } else {
+ // Node connection inputs
+ this.addInput(inputName, type);
+ widgetCreated = false;
+ }
+
+ if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
+ if (!config.widget.options) config.widget.options = {};
+ config.widget.options.forceInput = inputData[1].forceInput;
+ }
+ if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
+ if (!config.widget.options) config.widget.options = {};
+ config.widget.options.defaultInput = inputData[1].defaultInput;
+ }
+ }
+
+ for (const o in nodeData["output"]) {
+ let output = nodeData["output"][o];
+ if(output instanceof Array) output = "COMBO";
+ const outputName = nodeData["output_name"][o] || output;
+ const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
+ this.addOutput(outputName, output, { shape: outputShape });
+ }
+
+ const s = this.computeSize();
+ s[0] = Math.max(config.minWidth, s[0] * 1.5);
+ s[1] = Math.max(config.minHeight, s[1]);
+ this.size = s;
+ this.serialize_widgets = true;
+
+ app.#invokeExtensionsAsync("nodeCreated", this);
+ },
+ {
+ title: nodeData.display_name || nodeData.name,
+ comfyClass: nodeData.name,
+ nodeData
+ }
+ );
+ node.prototype.comfyClass = nodeData.name;
+
+ this.#addNodeContextMenuHandler(node);
+ this.#addDrawBackgroundHandler(node, app);
+ this.#addNodeKeyHandler(node);
+
+ await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
+ LiteGraph.registerNodeType(nodeId, node);
+ node.category = nodeData.category;
+ }
+
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
- const widgets = Object.assign(
+ this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
@@ -1377,75 +1459,7 @@ export class ComfyApp {
// Register a node for each definition
for (const nodeId in defs) {
- const nodeData = defs[nodeId];
- const node = Object.assign(
- function ComfyNode() {
- var inputs = nodeData["input"]["required"];
- if (nodeData["input"]["optional"] != undefined){
- inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
- }
- const config = { minWidth: 1, minHeight: 1 };
- for (const inputName in inputs) {
- const inputData = inputs[inputName];
- const type = inputData[0];
-
- let widgetCreated = true;
- if (Array.isArray(type)) {
- // Enums
- Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
- } else if (`${type}:${inputName}` in widgets) {
- // Support custom widgets by Type:Name
- Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
- } else if (type in widgets) {
- // Standard type widgets
- Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
- } else {
- // Node connection inputs
- this.addInput(inputName, type);
- widgetCreated = false;
- }
-
- if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
- if (!config.widget.options) config.widget.options = {};
- config.widget.options.forceInput = inputData[1].forceInput;
- }
- if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
- if (!config.widget.options) config.widget.options = {};
- config.widget.options.defaultInput = inputData[1].defaultInput;
- }
- }
-
- for (const o in nodeData["output"]) {
- let output = nodeData["output"][o];
- if(output instanceof Array) output = "COMBO";
- const outputName = nodeData["output_name"][o] || output;
- const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
- this.addOutput(outputName, output, { shape: outputShape });
- }
-
- const s = this.computeSize();
- s[0] = Math.max(config.minWidth, s[0] * 1.5);
- s[1] = Math.max(config.minHeight, s[1]);
- this.size = s;
- this.serialize_widgets = true;
-
- app.#invokeExtensionsAsync("nodeCreated", this);
- },
- {
- title: nodeData.display_name || nodeData.name,
- comfyClass: nodeData.name,
- nodeData
- }
- );
- node.prototype.comfyClass = nodeData.name;
-
- this.#addNodeContextMenuHandler(node);
- this.#addDrawBackgroundHandler(node, app);
- this.#addNodeKeyHandler(node);
-
- await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
- LiteGraph.registerNodeType(nodeId, node);
- node.category = nodeData.category;
+ this.registerNodeDef(nodeId, defs[nodeId]);
}
}
@@ -1488,9 +1502,14 @@ export class ComfyApp {
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
this.ui.dialog.show(
- `When loading the graph, the following node types were not found: ${Array.from(new Set(missingNodeTypes)).map(
- (t) => `- ${t}
`
- ).join("")}
${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
+ $el("div", [
+ $el("span", { textContent: "When loading the graph, the following node types were not found: " }),
+ $el(
+ "ul",
+ Array.from(new Set(missingNodeTypes)).map((t) => $el("li", { textContent: t }))
+ ),
+ ...(hasAddedNodes ? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })] : []),
+ ])
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
@@ -1501,7 +1520,7 @@ export class ComfyApp {
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
*/
- loadGraphData(graphData) {
+ async loadGraphData(graphData) {
this.clean();
let reset_invalid_values = false;
@@ -1519,6 +1538,7 @@ export class ComfyApp {
}
const missingNodeTypes = [];
+ await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes);
for (let n of graphData.nodes) {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
@@ -1527,8 +1547,8 @@ export class ComfyApp {
// Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) {
- n.type = sanitizeNodeName(n.type);
missingNodeTypes.push(n.type);
+ n.type = sanitizeNodeName(n.type);
}
}
@@ -1627,92 +1647,98 @@ export class ComfyApp {
* @returns The workflow and node links
*/
async graphToPrompt() {
- for (const node of this.graph.computeExecutionOrder(false)) {
- if (node.isVirtualNode) {
- // Don't serialize frontend only nodes but let them make changes
- if (node.applyToGraph) {
- node.applyToGraph();
+ for (const outerNode of this.graph.computeExecutionOrder(false)) {
+ const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
+ for (const node of innerNodes) {
+ if (node.isVirtualNode) {
+ // Don't serialize frontend only nodes but let them make changes
+ if (node.applyToGraph) {
+ node.applyToGraph();
+ }
}
- continue;
}
}
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
- for (const node of this.graph.computeExecutionOrder(false)) {
- const n = workflow.nodes.find((n) => n.id === node.id);
+ for (const outerNode of this.graph.computeExecutionOrder(false)) {
+ const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
+ for (const node of innerNodes) {
+ if (node.isVirtualNode) {
+ continue;
+ }
- if (node.isVirtualNode) {
- continue;
- }
+ if (node.mode === 2 || node.mode === 4) {
+ // Don't serialize muted nodes
+ continue;
+ }
- if (node.mode === 2 || node.mode === 4) {
- // Don't serialize muted nodes
- continue;
- }
+ const inputs = {};
+ const widgets = node.widgets;
- const inputs = {};
- const widgets = node.widgets;
-
- // Store all widget values
- if (widgets) {
- for (const i in widgets) {
- const widget = widgets[i];
- if (!widget.options || widget.options.serialize !== false) {
- inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
+ // Store all widget values
+ if (widgets) {
+ for (const i in widgets) {
+ const widget = widgets[i];
+ if (!widget.options || widget.options.serialize !== false) {
+ inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value;
+ }
}
}
- }
- // Store all node links
- for (let i in node.inputs) {
- let parent = node.getInputNode(i);
- if (parent) {
- let link = node.getInputLink(i);
- while (parent.mode === 4 || parent.isVirtualNode) {
- let found = false;
- if (parent.isVirtualNode) {
- link = parent.getInputLink(link.origin_slot);
- if (link) {
- parent = parent.getInputNode(link.target_slot);
- if (parent) {
- found = true;
- }
- }
- } else if (link && parent.mode === 4) {
- let all_inputs = [link.origin_slot];
- if (parent.inputs) {
- all_inputs = all_inputs.concat(Object.keys(parent.inputs))
- for (let parent_input in all_inputs) {
- parent_input = all_inputs[parent_input];
- if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
- link = parent.getInputLink(parent_input);
- if (link) {
- parent = parent.getInputNode(parent_input);
- }
+ // Store all node links
+ for (let i in node.inputs) {
+ let parent = node.getInputNode(i);
+ if (parent) {
+ let link = node.getInputLink(i);
+ while (parent.mode === 4 || parent.isVirtualNode) {
+ let found = false;
+ if (parent.isVirtualNode) {
+ link = parent.getInputLink(link.origin_slot);
+ if (link) {
+ parent = parent.getInputNode(link.target_slot);
+ if (parent) {
found = true;
- break;
+ }
+ }
+ } else if (link && parent.mode === 4) {
+ let all_inputs = [link.origin_slot];
+ if (parent.inputs) {
+ all_inputs = all_inputs.concat(Object.keys(parent.inputs))
+ for (let parent_input in all_inputs) {
+ parent_input = all_inputs[parent_input];
+ if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
+ link = parent.getInputLink(parent_input);
+ if (link) {
+ parent = parent.getInputNode(parent_input);
+ }
+ found = true;
+ break;
+ }
}
}
}
+
+ if (!found) {
+ break;
+ }
}
- if (!found) {
- break;
+ if (link) {
+ if (parent?.updateLink) {
+ link = parent.updateLink(link);
+ }
+ inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
-
- if (link) {
- inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
- }
}
- }
- output[String(node.id)] = {
- inputs,
- class_type: node.comfyClass,
- };
+ output[String(node.id)] = {
+ inputs,
+ class_type: node.comfyClass,
+ };
+ }
}
// Remove inputs connected to removed nodes
@@ -1832,7 +1858,7 @@ export class ComfyApp {
const pngInfo = await getPngMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
- this.loadGraphData(JSON.parse(pngInfo.workflow));
+ await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters);
}
@@ -1848,21 +1874,21 @@ export class ComfyApp {
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
- reader.onload = () => {
+ reader.onload = async () => {
const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
- this.loadGraphData(jsonContent);
+ await this.loadGraphData(jsonContent);
}
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
- this.loadGraphData(JSON.parse(info.workflow));
+ await this.loadGraphData(JSON.parse(info.workflow));
}
}
}
diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js
index 07da591c..37d26f3c 100644
--- a/web/scripts/domWidget.js
+++ b/web/scripts/domWidget.js
@@ -44,7 +44,7 @@ function getClipPath(node, element, elRect) {
}
function computeSize(size) {
- if (this.widgets?.[0].last_y == null) return;
+ if (this.widgets?.[0]?.last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
@@ -195,7 +195,6 @@ export function addDomClippingSetting() {
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
- console.log("enableDomClipping", enableDomClipping);
enableDomClipping = !!value;
},
});
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index 8a58d30b..ebaf86fe 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -462,8 +462,8 @@ class ComfyList {
return $el("div", {textContent: item.prompt[0] + ": "}, [
$el("button", {
textContent: "Load",
- onclick: () => {
- app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
+ onclick: async () => {
+ await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) {
app.nodeOutputs = item.outputs;
}
@@ -784,9 +784,9 @@ export class ComfyUI {
}
}),
$el("button", {
- id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
+ id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
- app.loadGraphData()
+ await app.loadGraphData()
}
}
}),
diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js
index fbc1d0fc..de5877e5 100644
--- a/web/scripts/widgets.js
+++ b/web/scripts/widgets.js
@@ -23,29 +23,73 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
}
-export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
- const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
+export function getWidgetType(inputData, inputName) {
+ const type = inputData[0];
+
+ if (Array.isArray(type)) {
+ return "COMBO";
+ } else if (`${type}:${inputName}` in ComfyWidgets) {
+ return `${type}:${inputName}`;
+ } else if (type in ComfyWidgets) {
+ return type;
+ } else {
+ return null;
+ }
+}
+
+export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
+ let name = inputData[1]?.control_after_generate;
+ if(typeof name !== "string") {
+ name = widgetName;
+ }
+ const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
addFilterList: false,
- });
+ controlAfterGenerateName: name
+ }, inputData);
return widgets[0];
}
-export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
+export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
+ if (!defaultValue) defaultValue = "randomize";
if (!options) options = {};
-
+
+ const getName = (defaultName, optionName) => {
+ let name = defaultName;
+ if (options[optionName]) {
+ name = options[optionName];
+ } else if (typeof inputData?.[1]?.[defaultName] === "string") {
+ name = inputData?.[1]?.[defaultName];
+ } else if (inputData?.[1]?.control_prefix) {
+ name = inputData?.[1]?.control_prefix + " " + name
+ }
+ return name;
+ }
+
const widgets = [];
- const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
- values: ["fixed", "increment", "decrement", "randomize"],
- serialize: false, // Don't include this in prompt.
- });
+ const valueControl = node.addWidget(
+ "combo",
+ getName("control_after_generate", "controlAfterGenerateName"),
+ defaultValue,
+ function () {},
+ {
+ values: ["fixed", "increment", "decrement", "randomize"],
+ serialize: false, // Don't include this in prompt.
+ }
+ );
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
- comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
- serialize: false, // Don't include this in prompt.
- });
+ comboFilter = node.addWidget(
+ "string",
+ getName("control_filter_list", "controlFilterListName"),
+ "",
+ function () {},
+ {
+ serialize: false, // Don't include this in prompt.
+ }
+ );
widgets.push(comboFilter);
}
@@ -96,7 +140,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
targetWidget.value = value;
targetWidget.callback(value);
}
- } else { //number
+ } else {
+ //number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
@@ -119,32 +164,54 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
default:
break;
}
- /*check if values are over or under their respective
- * ranges and set them to min or max.*/
- if (targetWidget.value < min)
- targetWidget.value = min;
+ /*check if values are over or under their respective
+ * ranges and set them to min or max.*/
+ if (targetWidget.value < min) targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
targetWidget.callback(targetWidget.value);
}
- }
-
+ };
return widgets;
};
-function seedWidget(node, inputName, inputData, app) {
- const seed = ComfyWidgets.INT(node, inputName, inputData, app);
- const seedControl = addValueControlWidget(node, seed.widget, "randomize");
+function seedWidget(node, inputName, inputData, app, widgetName) {
+ const seed = createIntWidget(node, inputName, inputData, app, true);
+ const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
seed.widget.linkedWidgets = [seedControl];
return seed;
}
+
+function createIntWidget(node, inputName, inputData, app, isSeedInput) {
+ const control = inputData[1]?.control_after_generate;
+ if (!isSeedInput && control) {
+ return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined);
+ }
+
+ let widgetType = isSlider(inputData[1]["display"], app);
+ const { val, config } = getNumberDefaults(inputData, 1, 0, true);
+ Object.assign(config, { precision: 0 });
+ return {
+ widget: node.addWidget(
+ widgetType,
+ inputName,
+ val,
+ function (v) {
+ const s = this.options.step / 10;
+ this.value = Math.round(v / s) * s;
+ },
+ config
+ ),
+ };
+}
+
function addMultilineWidget(node, name, opts, app) {
const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
- inputEl.placeholder = opts.placeholder || "";
+ inputEl.placeholder = opts.placeholder || name;
const widget = node.addDOMWidget(name, "customtext", inputEl, {
getValue() {
@@ -156,6 +223,10 @@ function addMultilineWidget(node, name, opts, app) {
});
widget.inputEl = inputEl;
+ inputEl.addEventListener("input", () => {
+ widget.callback?.(widget.value);
+ });
+
return { minWidth: 400, minHeight: 200, widget };
}
@@ -186,21 +257,7 @@ export const ComfyWidgets = {
}, config) };
},
INT(node, inputName, inputData, app) {
- let widgetType = isSlider(inputData[1]["display"], app);
- const { val, config } = getNumberDefaults(inputData, 1, 0, true);
- Object.assign(config, { precision: 0 });
- return {
- widget: node.addWidget(
- widgetType,
- inputName,
- val,
- function (v) {
- const s = this.options.step / 10;
- this.value = Math.round(v / s) * s;
- },
- config
- ),
- };
+ return createIntWidget(node, inputName, inputData, app);
},
BOOLEAN(node, inputName, inputData) {
let defaultVal = false;
@@ -245,10 +302,14 @@ export const ComfyWidgets = {
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
- return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
+ const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
+ if (inputData[1]?.control_after_generate) {
+ res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
+ }
+ return res;
},
IMAGEUPLOAD(node, inputName, inputData, app) {
- const imageWidget = node.widgets.find((w) => w.name === "image");
+ const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));
let uploadWidget;
function showImage(name) {
@@ -362,9 +423,10 @@ export const ComfyWidgets = {
document.body.append(fileInput);
// Create the button widget for selecting the files
- uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
+ uploadWidget = node.addWidget("button", inputName, "image", () => {
fileInput.click();
});
+ uploadWidget.label = "choose file to upload";
uploadWidget.serialize = false;
// Add handler to check if an image is being dragged over our node