diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fae66612..3e6b1daa 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -44,11 +44,12 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") -class LatentPreviewType(enum.Enum): +class LatentPreviewMethod(enum.Enum): + Auto = "auto" Latent2RGB = "latent2rgb" TAESD = "taesd" parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") -parser.add_argument("--default-preview-method", type=str, default=LatentPreviewType.Latent2RGB, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") +parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here new file mode 100644 index 00000000..e69de29b diff --git a/nodes.py b/nodes.py index 74c664bd..6266b6c0 100644 --- a/nodes.py +++ b/nodes.py @@ -24,7 +24,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils -from comfy.cli_args import args, LatentPreviewType +from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -1018,11 +1018,19 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, previewer = None if not args.disable_previews: # TODO previewer methods - if args.default_preview_method == LatentPreviewType.TAESD: - encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") - decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") - if encoder_path and decoder_path: - taesd = TAESD(encoder_path, decoder_path).to(device) + taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") + taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") + + method = args.default_preview_method + + if args.default_preview_method == LatentPreviewMethod.AUTO: + method = LatentPreviewMethod.Latent2RGB + if taesd_encoder_path and taesd_encoder_path: + method = LatentPreviewMethod.TAESD + + if method == LatentPreviewMethod.TAESD: + if taesd_encoder_path and taesd_encoder_path: + taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth")