diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py new file mode 100644 index 00000000..1d914fa9 --- /dev/null +++ b/comfy_extras/nodes_torch_compile.py @@ -0,0 +1,21 @@ +import torch + +class TorchCompileModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + EXPERIMENTAL = True + + def patch(self, model): + m = model.clone() + m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"))) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "TorchCompileModel": TorchCompileModel, +} diff --git a/nodes.py b/nodes.py index bbe73282..1f14aaf1 100644 --- a/nodes.py +++ b/nodes.py @@ -2102,6 +2102,7 @@ def init_builtin_extra_nodes(): "nodes_hunyuan.py", "nodes_flux.py", "nodes_lora_extract.py", + "nodes_torch_compile.py", ] import_failed = []