preview method autodetection

This commit is contained in:
space-nuko 2023-06-05 18:59:10 -05:00
parent d5a28fadaa
commit 3e17971acb
3 changed files with 17 additions and 8 deletions

View File

@ -44,11 +44,12 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class LatentPreviewType(enum.Enum):
class LatentPreviewMethod(enum.Enum):
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewType.Latent2RGB, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")

View File

@ -24,7 +24,7 @@ import comfy.samplers
import comfy.sample
import comfy.sd
import comfy.utils
from comfy.cli_args import args, LatentPreviewType
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.clip_vision
@ -1018,11 +1018,19 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
previewer = None
if not args.disable_previews:
# TODO previewer methods
if args.default_preview_method == LatentPreviewType.TAESD:
encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth")
decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth")
if encoder_path and decoder_path:
taesd = TAESD(encoder_path, decoder_path).to(device)
taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth")
taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth")
method = args.default_preview_method
if args.default_preview_method == LatentPreviewMethod.AUTO:
method = LatentPreviewMethod.Latent2RGB
if taesd_encoder_path and taesd_encoder_path:
method = LatentPreviewMethod.TAESD
if method == LatentPreviewMethod.TAESD:
if taesd_encoder_path and taesd_encoder_path:
taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth")