diff --git a/nodes.py b/nodes.py index a230f725..8d8b153f 100644 --- a/nodes.py +++ b/nodes.py @@ -1889,7 +1889,29 @@ NODE_DISPLAY_NAME_MAPPINGS = { EXTENSION_WEB_DIRS = {} -def load_custom_node(module_path, ignore=set()): +def get_relative_module_name(module_path: str) -> str: + """ + Returns the module name based on the given module path. + Examples: + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "custom_nodes.my_custom_node" + get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes.my + Args: + module_path (str): The path of the module. + Returns: + str: The module name. + """ + relative_path = os.path.relpath(module_path, folder_paths.base_path) + if os.path.isfile(module_path): + relative_path = os.path.splitext(relative_path)[0] + return relative_path.replace(os.sep, '.') + + +def load_custom_node(module_path: str, ignore=set()) -> bool: module_name = os.path.basename(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) @@ -1913,9 +1935,10 @@ def load_custom_node(module_path, ignore=set()): EXTENSION_WEB_DIRS[module_name] = web_dir if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: - for name in module.NODE_CLASS_MAPPINGS: + for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): if name not in ignore: - NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name] + NODE_CLASS_MAPPINGS[name] = node_cls + node_cls.RELATIVE_PYTHON_MODULE = get_relative_module_name(module_path) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True diff --git a/server.py b/server.py index 38b1bab8..ce7c7532 100644 --- a/server.py +++ b/server.py @@ -416,6 +416,7 @@ class PromptServer(): info['name'] = node_class info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' + info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: info['output_node'] = True