From ccb6b70de19fd277adb328945411732dbd8dede4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Jul 2023 17:37:27 -0400 Subject: [PATCH] Move image encoding outside of sampling loop for better preview perf. --- latent_preview.py | 24 ++---------------------- main.py | 6 +++--- server.py | 31 +++++++++++++++++++++++++++++-- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/latent_preview.py b/latent_preview.py index 833e6822..30c1d131 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -1,6 +1,5 @@ import torch -from PIL import Image, ImageOps -from io import BytesIO +from PIL import Image import struct import numpy as np from comfy.cli_args import args, LatentPreviewMethod @@ -15,26 +14,7 @@ class LatentPreviewer: def decode_latent_to_preview_image(self, preview_format, x0): preview_image = self.decode_latent_to_preview(x0) - - if hasattr(Image, 'Resampling'): - resampling = Image.Resampling.BILINEAR - else: - resampling = Image.ANTIALIAS - - preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling) - - preview_type = 1 - if preview_format == "JPEG": - preview_type = 1 - elif preview_format == "PNG": - preview_type = 2 - - bytesIO = BytesIO() - header = struct.pack(">I", preview_type) - bytesIO.write(header) - preview_image.save(bytesIO, format=preview_format, quality=95) - preview_bytes = bytesIO.getvalue() - return preview_bytes + return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION) class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd): diff --git a/main.py b/main.py index b98f5d21..21f76b61 100644 --- a/main.py +++ b/main.py @@ -92,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): - def hook(value, total, preview_image_bytes): + def hook(value, total, preview_image): server.send_sync("progress", {"value": value, "max": total}, server.client_id) - if preview_image_bytes is not None: - server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) + if preview_image is not None: + server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) diff --git a/server.py b/server.py index 9ca131ed..f61b11a9 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import uuid import json import glob import struct -from PIL import Image +from PIL import Image, ImageOps from io import BytesIO try: @@ -29,6 +29,7 @@ import comfy.model_management class BinaryEventTypes: PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 async def send_socket_catch_exception(function, message): try: @@ -498,7 +499,9 @@ class PromptServer(): return prompt_info async def send(self, event, data, sid=None): - if isinstance(data, (bytes, bytearray)): + if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: + await self.send_image(data, sid=sid) + elif isinstance(data, (bytes, bytearray)): await self.send_bytes(event, data, sid) else: await self.send_json(event, data, sid) @@ -512,6 +515,30 @@ class PromptServer(): message.extend(data) return message + async def send_image(self, image_data, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + type_num = 1 + if image_type == "JPEG": + type_num = 1 + elif image_type == "PNG": + type_num = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", type_num) + bytesIO.write(header) + image.save(bytesIO, format=image_type, quality=95, compress_level=4) + preview_bytes = bytesIO.getvalue() + await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data)