diff --git a/nodes.py b/nodes.py index e46aed82..27a329c6 100644 --- a/nodes.py +++ b/nodes.py @@ -1673,6 +1673,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "VAEEncodeTiled": "VAE Encode (Tiled)", } +EXTENSION_WEB_DIRS = {} + def load_custom_node(module_path, ignore=set()): module_name = os.path.basename(module_path) if os.path.isfile(module_path): @@ -1681,11 +1683,20 @@ def load_custom_node(module_path, ignore=set()): try: if os.path.isfile(module_path): module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module_dir = os.path.split(module_path)[0] else: module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) + module_dir = module_path + module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module module_spec.loader.exec_module(module) + + if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None: + web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY"))) + if os.path.isdir(web_dir): + EXTENSION_WEB_DIRS[module_name] = web_dir + if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: for name in module.NODE_CLASS_MAPPINGS: if name not in ignore: diff --git a/server.py b/server.py index fab33be3..344847b3 100644 --- a/server.py +++ b/server.py @@ -5,6 +5,7 @@ import nodes import folder_paths import execution import uuid +import urllib import json import glob import struct @@ -67,6 +68,8 @@ class PromptServer(): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + + self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop self.messages = asyncio.Queue() @@ -123,8 +126,18 @@ class PromptServer(): @routes.get("/extensions") async def get_extensions(request): - files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) - return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + files = glob.glob(os.path.join( + self.web_root, 'extensions/**/*.js'), recursive=True) + + extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) + + for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True) + extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( + name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) + print(extensions) + + return web.json_response(extensions) def get_dir_by_type(dir_type): if dir_type is None: @@ -492,6 +505,12 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) + + for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + self.app.add_routes([ + web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True), + ]) + self.app.add_routes([ web.static('/', self.web_root, follow_symlinks=True), ])