reroute + primitive tests

This commit is contained in:
pythongosssss 2023-12-05 20:28:05 +00:00
parent 44265e0810
commit a99da6667f
2 changed files with 171 additions and 8 deletions

View File

@ -1,7 +1,13 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
const {
start,
makeNodeDef,
checkBeforeAndAfterReload,
assertNotNullOrUndefined,
createDefaultWorkflow,
} = require("../utils");
const lg = require("../utils/litegraph");
/**
@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
if (widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
@ -308,8 +314,8 @@ describe("widget inputs", () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
@ -330,7 +336,7 @@ describe("widget inputs", () => {
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
@ -380,7 +386,7 @@ describe("widget inputs", () => {
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
for (let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
@ -392,4 +398,160 @@ describe("widget inputs", () => {
control["afterQueued"]();
expect(value.value).toBe("B");
});
describe("reroutes", () => {
async function checkOutput(graph, values) {
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
4: {
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
class_type: "EmptyLatentImage",
},
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: values?.scheduler ?? "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
7: {
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
class_type: "SaveImage",
},
});
}
async function waitForWidget(node) {
// widgets are created slightly after the graph is ready
// hard to find an exact hook to get these so just wait for them to be ready
for (let i = 0; i < 10; i++) {
await new Promise((r) => setTimeout(r, 10));
if (node.widgets?.value) {
return;
}
}
}
it("can connect primitive via a reroute path to a widget input", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.sampler.widgets.scheduler.convertToInput();
nodes.save.widgets.filename_prefix.convertToInput();
let widthReroute = ez.Reroute();
let schedulerReroute = ez.Reroute();
let fileReroute = ez.Reroute();
let widthNext = widthReroute;
let schedulerNext = schedulerReroute;
let fileNext = fileReroute;
for (let i = 0; i < 5; i++) {
let next = ez.Reroute();
widthNext.outputs[0].connectTo(next.inputs[0]);
widthNext = next;
next = ez.Reroute();
schedulerNext.outputs[0].connectTo(next.inputs[0]);
schedulerNext = next;
next = ez.Reroute();
fileNext.outputs[0].connectTo(next.inputs[0]);
fileNext = next;
}
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
let widthPrimitive = ez.PrimitiveNode();
let schedulerPrimitive = ez.PrimitiveNode();
let filePrimitive = ez.PrimitiveNode();
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
expect(widthPrimitive.widgets.value.value).toBe(512);
widthPrimitive.widgets.value.value = 1024;
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
schedulerPrimitive.widgets.value.value = "simple";
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
filePrimitive.widgets.value.value = "ComfyTest";
await checkBeforeAndAfterReload(graph, async () => {
widthPrimitive = graph.find(widthPrimitive);
schedulerPrimitive = graph.find(schedulerPrimitive);
filePrimitive = graph.find(filePrimitive);
await waitForWidget(filePrimitive);
expect(widthPrimitive.widgets.length).toBe(2);
expect(schedulerPrimitive.widgets.length).toBe(3);
expect(filePrimitive.widgets.length).toBe(1);
await checkOutput(graph, {
width: 1024,
scheduler: "simple",
filename_prefix: "ComfyTest",
});
});
});
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.empty.widgets.height.convertToInput();
nodes.empty.widgets.batch_size.convertToInput();
let reroute = ez.Reroute();
let prevReroute = reroute;
for (let i = 0; i < 5; i++) {
const next = ez.Reroute();
prevReroute.outputs[0].connectTo(next.inputs[0]);
prevReroute = next;
}
const r1 = ez.Reroute(prevReroute.outputs[0]);
const r2 = ez.Reroute(prevReroute.outputs[0]);
const r3 = ez.Reroute(r2.outputs[0]);
const r4 = ez.Reroute(r2.outputs[0]);
r1.outputs[0].connectTo(nodes.empty.inputs.width);
r3.outputs[0].connectTo(nodes.empty.inputs.height);
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(reroute.inputs[0]);
expect(primitive.widgets.value.value).toBe(1);
primitive.widgets.value.value = 64;
await checkBeforeAndAfterReload(graph, async (r) => {
primitive = graph.find(primitive);
await waitForWidget(primitive);
// Ensure widget configs are merged
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
await checkOutput(graph, {
width: 64,
height: 64,
batch_size: 64,
});
});
});
});
});

View File

@ -117,7 +117,7 @@ export class EzOutput extends EzSlot {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
@ -179,6 +179,7 @@ export class EzWidget {
set value(v) {
this.widget.value = v;
this.widget.callback?.call?.(this.widget, v)
}
get isConvertedToInput() {
@ -319,7 +320,7 @@ export class EzGraph {
}
stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
return JSON.stringify(this.app.graph.serialize(), undefined);
}
/**