Make new LATENT_PREVIEWER type for declaring KSampler preview methods
This commit is contained in:
parent
a9fa2d3727
commit
f326a0a468
52
nodes.py
52
nodes.py
|
@ -34,6 +34,11 @@ import importlib
|
|||
import folder_paths
|
||||
|
||||
|
||||
class LatentPreviewer:
|
||||
def decode_latent_to_preview(self, device, x0):
|
||||
pass
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
|
@ -282,6 +287,27 @@ class TAESDEncode:
|
|||
samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
|
||||
return ({"samples": samples}, )
|
||||
|
||||
class TAESDPreviewerImpl(LatentPreviewer):
|
||||
def __init__(self, taesd):
|
||||
self.taesd = taesd
|
||||
|
||||
def decode_latent_to_preview(self, device, x0):
|
||||
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
|
||||
x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2]
|
||||
x_sample = x_sample * 0.5
|
||||
return x_sample
|
||||
|
||||
class TAESDPreviewer:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "taesd": ("TAESD", ), }}
|
||||
RETURN_TYPES = ("LATENT_PREVIEWER",)
|
||||
FUNCTION = "make_previewer"
|
||||
|
||||
CATEGORY = "latent/previewer"
|
||||
|
||||
def make_previewer(self, taesd):
|
||||
return (TAESDPreviewerImpl(taesd), )
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
|
@ -986,10 +1012,8 @@ class SetLatentNoiseMask:
|
|||
return (s,)
|
||||
|
||||
|
||||
def decode_latent_to_preview_image(taesd, device, preview_format, x0):
|
||||
x_sample = taesd.decoder(x0.to(device))[0].detach()
|
||||
x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
|
||||
x_sample = x_sample * 0.5
|
||||
def decode_latent_to_preview_image(previewer, device, preview_format, x0):
|
||||
x_sample = previewer.decode_latent_to_preview(device, x0)
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
|
@ -1015,7 +1039,7 @@ def decode_latent_to_preview_image(taesd, device, preview_format, x0):
|
|||
return preview_bytes
|
||||
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
latent_image = latent["samples"]
|
||||
|
||||
|
@ -1036,8 +1060,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||
pbar = comfy.utils.ProgressBar(steps)
|
||||
def callback(step, x0, x, total_steps):
|
||||
preview_bytes = None
|
||||
if taesd:
|
||||
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
|
||||
if previewer:
|
||||
preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
|
@ -1063,7 +1087,7 @@ class KSampler:
|
|||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"taesd": ("TAESD",)
|
||||
"previewer": ("LATENT_PREVIEWER",)
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
|
@ -1071,8 +1095,8 @@ class KSampler:
|
|||
|
||||
CATEGORY = "sampling"
|
||||
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer)
|
||||
|
||||
class KSamplerAdvanced:
|
||||
@classmethod
|
||||
|
@ -1093,7 +1117,7 @@ class KSamplerAdvanced:
|
|||
"return_with_leftover_noise": (["disable", "enable"], ),
|
||||
},
|
||||
"optional": {
|
||||
"taesd": ("TAESD",)
|
||||
"previewer": ("LATENT_PREVIEWER",)
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
|
@ -1101,14 +1125,14 @@ class KSamplerAdvanced:
|
|||
|
||||
CATEGORY = "sampling"
|
||||
|
||||
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
|
||||
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None):
|
||||
force_full_denoise = True
|
||||
if return_with_leftover_noise == "enable":
|
||||
force_full_denoise = False
|
||||
disable_noise = False
|
||||
if add_noise == "disable":
|
||||
disable_noise = True
|
||||
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
|
||||
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer)
|
||||
|
||||
class SaveImage:
|
||||
def __init__(self):
|
||||
|
@ -1369,6 +1393,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"VAELoader": VAELoader,
|
||||
"TAESDDecode": TAESDDecode,
|
||||
"TAESDEncode": TAESDEncode,
|
||||
"TAESDPreviewer": TAESDPreviewer,
|
||||
"TAESDLoader": TAESDLoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
|
@ -1425,6 +1450,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"VAELoader": "Load VAE",
|
||||
"TAESDLoader": "Load TAESD",
|
||||
"TAESDPreviewer": "TAESD Previewer",
|
||||
"LoraLoader": "Load LoRA",
|
||||
"CLIPLoader": "Load CLIP",
|
||||
"ControlNetLoader": "Load ControlNet Model",
|
||||
|
|
Loading…
Reference in New Issue