From d0b7ab88ba0f1cb4ab16e0425f5229e60c934536 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 12 Sep 2024 05:23:32 -0400
Subject: [PATCH] Add a simple experimental TorchCompileModel node.

It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer, try using
this node with fp8_e4m3fn and the --fast argument.
---
 comfy_extras/nodes_torch_compile.py | 21 +++++++++++++++++++++
 nodes.py                            |  1 +
 2 files changed, 22 insertions(+)
 create mode 100644 comfy_extras/nodes_torch_compile.py

diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py
new file mode 100644
index 00000000..1d914fa9
--- /dev/null
+++ b/comfy_extras/nodes_torch_compile.py
@@ -0,0 +1,21 @@
+import torch
+
+class TorchCompileModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+    EXPERIMENTAL = True
+
+    def patch(self, model):
+        m = model.clone()
+        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "TorchCompileModel": TorchCompileModel,
+}
diff --git a/nodes.py b/nodes.py
index bbe73282..1f14aaf1 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2102,6 +2102,7 @@ def init_builtin_extra_nodes():
         "nodes_hunyuan.py",
         "nodes_flux.py",
         "nodes_lora_extract.py",
+        "nodes_torch_compile.py",
     ]
 
     import_failed = []
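
For context: torch.compile wraps a module in a just-in-time compiled version of itself. The first call traces and compiles the forward pass, and later calls with the same input shapes reuse the compiled graph. Below is a minimal standalone sketch of that behaviour outside ComfyUI, using a hypothetical ToyDiffusionModel in place of the real "diffusion_model" object; it is an illustration, not part of the patch.

import torch
import torch.nn as nn

# Hypothetical stand-in for the object the node patches; in ComfyUI the real
# module is whatever m.get_model_object("diffusion_model") returns.
class ToyDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

    def forward(self, x):
        return self.net(x)

model = ToyDiffusionModel()

# torch.compile returns a wrapper around the module; compilation happens
# lazily on the first forward call. As the commit message notes, the
# compile backend may only work reliably on Linux.
compiled = torch.compile(model=model)

x = torch.randn(8, 64)
out = compiled(x)   # first call triggers compilation; later calls reuse the compiled graph

The node itself does the same thing non-destructively: it clones the incoming model wrapper and registers the compiled module as an object patch for "diffusion_model", so the original model object is left untouched and the compilation only applies to the cloned output of the node.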