From 04b308229ee59b5aebc0c78ea416e0b3ac22c146 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 31 May 2024 11:18:37 -0400 Subject: [PATCH] Small refactor of preview code. --- latent_preview.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/latent_preview.py b/latent_preview.py index b258fcf2..54aa233f 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,6 +11,13 @@ import logging MAX_PREVIEW_RESOLUTION = 512 +def preview_to_image(latent_image): + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) + + return Image.fromarray(latents_ubyte.numpy()) + class LatentPreviewer: def decode_latent_to_preview(self, x0): pass @@ -24,12 +31,8 @@ class TAESDPreviewerImpl(LatentPreviewer): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decode(x0[:1])[0].detach() - x_sample = 255. * torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = np.moveaxis(x_sample.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(x_sample.device)).numpy(), 0, 2) - - preview_image = Image.fromarray(x_sample) - return preview_image + x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) + return preview_to_image(x_sample) class Latent2RGBPreviewer(LatentPreviewer): @@ -39,13 +42,7 @@ class Latent2RGBPreviewer(LatentPreviewer): def decode_latent_to_preview(self, x0): self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors - - latents_ubyte = (((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) - - return Image.fromarray(latents_ubyte.numpy()) + return preview_to_image(latent_image) def get_previewer(device, latent_format):