diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 1d914fa9..1fe6f42c 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -4,6 +4,7 @@ class TorchCompileModel: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), + "backend": (["inductor", "cudagraphs"],), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" @@ -11,9 +12,9 @@ class TorchCompileModel: CATEGORY = "_for_testing" EXPERIMENTAL = True - def patch(self, model): + def patch(self, model, backend): m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"))) + m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend)) return (m, ) NODE_CLASS_MAPPINGS = {