Move image encoding outside of sampling loop for better preview perf.

This commit is contained in:
comfyanonymous 2023-07-19 17:37:27 -04:00
parent 39c58b227f
commit ccb6b70de1
3 changed files with 34 additions and 27 deletions

View File

@ -1,6 +1,5 @@
import torch import torch
from PIL import Image, ImageOps from PIL import Image
from io import BytesIO
import struct import struct
import numpy as np import numpy as np
from comfy.cli_args import args, LatentPreviewMethod from comfy.cli_args import args, LatentPreviewMethod
@ -15,26 +14,7 @@ class LatentPreviewer:
def decode_latent_to_preview_image(self, preview_format, x0): def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0) preview_image = self.decode_latent_to_preview(x0)
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
class TAESDPreviewerImpl(LatentPreviewer): class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd): def __init__(self, taesd):

View File

@ -92,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server): def hijack_progress(server):
def hook(value, total, preview_image_bytes): def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id) server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None: if preview_image is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)

View File

@ -8,7 +8,7 @@ import uuid
import json import json
import glob import glob
import struct import struct
from PIL import Image from PIL import Image, ImageOps
from io import BytesIO from io import BytesIO
try: try:
@ -29,6 +29,7 @@ import comfy.model_management
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
@ -498,7 +499,9 @@ class PromptServer():
return prompt_info return prompt_info
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
if isinstance(data, (bytes, bytearray)): if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid) await self.send_bytes(event, data, sid)
else: else:
await self.send_json(event, data, sid) await self.send_json(event, data, sid)
@ -512,6 +515,30 @@ class PromptServer():
message.extend(data) message.extend(data)
return message return message
async def send_image(self, image_data, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
if image_type == "JPEG":
type_num = 1
elif image_type == "PNG":
type_num = 2
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_bytes(self, event, data, sid=None): async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)