diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js new file mode 100644 index 00000000..b82e55c3 --- /dev/null +++ b/tests-ui/tests/extensions.test.js @@ -0,0 +1,196 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("extensions", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + it("calls each extension hook", async () => { + const mockExtension = { + name: "TestExtension", + init: jest.fn(), + setup: jest.fn(), + addCustomNodeDefs: jest.fn(), + getCustomWidgets: jest.fn(), + beforeRegisterNodeDef: jest.fn(), + registerCustomNodes: jest.fn(), + loadedGraphNode: jest.fn(), + nodeCreated: jest.fn(), + beforeConfigureGraph: jest.fn(), + afterConfigureGraph: jest.fn(), + }; + + const { app, ez, graph } = await start({ + async preSetup(app) { + app.registerExtension(mockExtension); + }, + }); + + // Basic initialisation hooks should be called once, with app + expect(mockExtension.init).toHaveBeenCalledTimes(1); + expect(mockExtension.init).toHaveBeenCalledWith(app); + + // Adding custom node defs should be passed the full list of nodes + expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); + expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app); + const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]; + expect(defs).toHaveProperty("KSampler"); + expect(defs).toHaveProperty("LoadImage"); + + // Get custom widgets is called once and should return new widget types + expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); + expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app); + + // Before register node def will be called once per node type + const nodeNames = Object.keys(defs); + const nodeCount = nodeNames.length; + expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); + for (let i = 0; i < nodeCount; i++) { + // It should be send the JS class and the original JSON definition + const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; + const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; + + expect(nodeClass.name).toBe("ComfyNode"); + expect(nodeClass.comfyClass).toBe(nodeNames[i]); + expect(nodeDef.name).toBe(nodeNames[i]); + expect(nodeDef).toHaveProperty("input"); + expect(nodeDef).toHaveProperty("output"); + } + + // Register custom nodes is called once after registerNode defs to allow adding other frontend nodes + expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); + + // Before configure graph will be called here as the default graph is being loaded + expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1); + // it gets sent the graph data that is going to be loaded + const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]; + + // A node created is fired for each node constructor that is called + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length); + for (let i = 0; i < graphData.nodes.length; i++) { + expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type); + } + + // Each node then calls loadedGraphNode to allow them to be updated + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); + for (let i = 0; i < graphData.nodes.length; i++) { + expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type); + } + + // After configure is then called once all the setup is done + expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1); + + expect(mockExtension.setup).toHaveBeenCalledTimes(1); + expect(mockExtension.setup).toHaveBeenCalledWith(app); + + // Ensure hooks are called in the correct order + const callOrder = [ + "init", + "addCustomNodeDefs", + "getCustomWidgets", + "beforeRegisterNodeDef", + "registerCustomNodes", + "beforeConfigureGraph", + "nodeCreated", + "loadedGraphNode", + "afterConfigureGraph", + "setup", + ]; + for (let i = 1; i < callOrder.length; i++) { + const fn1 = mockExtension[callOrder[i - 1]]; + const fn2 = mockExtension[callOrder[i]]; + expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]); + } + + graph.clear(); + + // Ensure adding a new node calls the correct callback + ez.LoadImage(); + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1); + expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage"); + + // Reload the graph to ensure correct hooks are fired + await graph.reload(); + + // These hooks should not be fired again + expect(mockExtension.init).toHaveBeenCalledTimes(1); + expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); + expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); + expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); + expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); + expect(mockExtension.setup).toHaveBeenCalledTimes(1); + + // These should be called again + expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2); + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); + expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); + }); + + it("allows custom nodeDefs and widgets to be registered", async () => { + const widgetMock = jest.fn((node, inputName, inputData, app) => { + expect(node.constructor.comfyClass).toBe("TestNode"); + expect(inputName).toBe("test_input"); + expect(inputData[0]).toBe("CUSTOMWIDGET"); + expect(inputData[1]?.hello).toBe("world"); + expect(app).toStrictEqual(app); + + return { + widget: node.addWidget("button", inputName, "hello", () => {}), + }; + }); + + // Register our extension that adds a custom node + widget type + const mockExtension = { + name: "TestExtension", + addCustomNodeDefs: (nodeDefs) => { + nodeDefs["TestNode"] = { + output: [], + output_name: [], + output_is_list: [], + name: "TestNode", + display_name: "TestNode", + category: "Test", + input: { + required: { + test_input: ["CUSTOMWIDGET", { hello: "world" }], + }, + }, + }; + }, + getCustomWidgets: jest.fn(() => { + return { + CUSTOMWIDGET: widgetMock, + }; + }), + }; + + const { graph, ez } = await start({ + async preSetup(app) { + app.registerExtension(mockExtension); + }, + }); + + expect(mockExtension.getCustomWidgets).toBeCalledTimes(1); + + graph.clear(); + expect(widgetMock).toBeCalledTimes(0); + const node = ez.TestNode(); + expect(widgetMock).toBeCalledTimes(1); + + // Ensure our custom widget is created + expect(node.inputs.length).toBe(0); + expect(node.widgets.length).toBe(1); + const w = node.widgets[0].widget; + expect(w.name).toBe("test_input"); + expect(w.type).toBe("button"); + }); +}); diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index eeccdb3d..3a018f56 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -4,11 +4,11 @@ const lg = require("./litegraph"); /** * - * @param { Parameters[0] & { resetEnv?: boolean } } config + * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config * @returns */ -export async function start(config = undefined) { - if(config?.resetEnv) { +export async function start(config = {}) { + if(config.resetEnv) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); @@ -16,6 +16,7 @@ export async function start(config = undefined) { mockApi(config); const { app } = require("../../web/scripts/app"); + config.preSetup?.(app); await app.setup(); return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } diff --git a/web/scripts/app.js b/web/scripts/app.js index a72e3002..861db16b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1654,6 +1654,7 @@ export class ComfyApp { if (missingNodeTypes.length) { this.showMissingNodesError(missingNodeTypes); } + await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes); } /**