Add a simple experimental TorchCompileModel node.

It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
This commit is contained in:
comfyanonymous 2024-09-12 05:23:32 -04:00
parent 405b529545
commit d0b7ab88ba
2 changed files with 22 additions and 0 deletions

View File

@ -0,0 +1,21 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -2102,6 +2102,7 @@ def init_builtin_extra_nodes():
"nodes_hunyuan.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
]
import_failed = []