Preview sampled images with TAESD

This commit is contained in:
space-nuko 2023-05-30 20:43:29 -05:00
parent 2ec980bb9f
commit b4f434ee66
8 changed files with 324 additions and 52 deletions

comfy/taesd/taesd.py (new file)

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 4),
    )

def Decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )

class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD from the given checkpoints."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
        if decoder_path is not None:
            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
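
A quick way to sanity-check this module outside ComfyUI (a minimal sketch; it assumes the two checkpoint files from the upstream TAESD release sit in the working directory):

import torch
from comfy.taesd.taesd import TAESD

taesd = TAESD("taesd_encoder.pth", "taesd_decoder.pth")
with torch.no_grad():
    latent = taesd.encoder(torch.rand(1, 3, 512, 512))  # three stride-2 convs: 512 -> 64, 4 channels
    pixels = taesd.decoder(latent).clamp(0, 1)          # three 2x upsamples: back to [1, 3, 512, 512]
print(latent.shape, pixels.shape)                       # [1, 4, 64, 64] and [1, 3, 512, 512]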

comfy/utils.py

@@ -197,14 +197,14 @@ class ProgressBar:
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK

-    def update_absolute(self, value, total=None):
+    def update_absolute(self, value, total=None, preview=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total)
+            self.hook(self.current, self.total, preview)

     def update(self, value):
         self.update_absolute(self.current + value)
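
The hook contract grows from (value, total) to (value, total, preview), where preview is either None or encoded preview-image bytes. A minimal sketch of a custom global hook (my_hook is a hypothetical name):

import comfy.utils

def my_hook(value, total, preview):
    # preview is None, or the encoded preview image bytes produced during sampling
    print(f"{value}/{total}", "(preview attached)" if preview is not None else "")

comfy.utils.set_progress_bar_global_hook(my_hook)
pbar = comfy.utils.ProgressBar(10)
pbar.update(1)  # calls my_hook(1, 10, None)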

main.py

@@ -26,6 +26,7 @@ import yaml
 import execution
 import folder_paths
 import server
+from server import BinaryEventTypes
 from nodes import init_custom_nodes
@@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

 def hijack_progress(server):
-    def hook(value, total):
+    def hook(value, total, preview_bytes_jpeg):
         server.send_sync("progress", { "value": value, "max": total}, server.client_id)
+        if preview_bytes_jpeg is not None:
+            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes_jpeg, server.client_id)
     comfy.utils.set_progress_bar_global_hook(hook)

 def cleanup_temp():

nodes.py

@@ -7,6 +7,8 @@ import hashlib
 import traceback
 import math
 import time
+import struct
+from io import BytesIO

 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
@@ -22,6 +24,7 @@ import comfy.samplers
 import comfy.sample
 import comfy.sd
 import comfy.utils
+from comfy.taesd.taesd import TAESD

 import comfy.clip_vision
@@ -38,6 +41,7 @@ def interrupt_processing(value=True):
     comfy.model_management.interrupt_current_processing(value)

 MAX_RESOLUTION=8192
+MAX_PREVIEW_RESOLUTION = 512

 class CLIPTextEncode:
     @classmethod
@@ -171,6 +175,21 @@ class VAEDecodeTiled:
     def decode(self, vae, samples):
         return (vae.decode_tiled(samples["samples"]), )

+class TAESDDecode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "taesd": ("TAESD", )}}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "decode"
+
+    CATEGORY = "latent"
+
+    def decode(self, taesd, samples):
+        device = comfy.model_management.get_torch_device()
+        # [B, C, H, W] -> [B, H, W, C]
+        pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1)
+        return (pixels, )
+
 class VAEEncode:
     @classmethod
     def INPUT_TYPES(s):
@@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
         return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

+class TAESDEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "pixels": ("IMAGE", ), "taesd": ("TAESD", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "latent"
+
+    def encode(self, taesd, pixels):
+        device = comfy.model_management.get_torch_device()
+        # [B, H, W, C] -> [B, C, H, W]
+        samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
+        return ({"samples": samples}, )
+
 class SaveLatent:
     def __init__(self):
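
Both nodes are thin wrappers that convert between ComfyUI's channels-last IMAGE layout and the channels-first latent layout. A rough sketch of the round trip, assuming a taesd instance already loaded as TAESDLoader below returns it:

import torch

image = torch.rand(1, 512, 512, 3)                               # ComfyUI IMAGE: [B, H, W, C] in [0, 1]
latent = {"samples": taesd.encoder(image.permute(0, 3, 1, 2))}   # LATENT: [1, 4, 64, 64]
decoded = taesd.decoder(latent["samples"]).permute(0, 2, 3, 1).clamp(0, 1)  # back to IMAGE layout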
@@ -464,6 +498,26 @@ class VAELoader:
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)

+class TAESDLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        model_list = folder_paths.get_filename_list("taesd")
+        return {"required": {
+            "encoder_name": (model_list, { "default": "taesd_encoder.pth" }),
+            "decoder_name": (model_list, { "default": "taesd_decoder.pth" })
+        }}
+    RETURN_TYPES = ("TAESD",)
+    FUNCTION = "load_taesd"
+
+    CATEGORY = "loaders"
+
+    def load_taesd(self, encoder_name, decoder_name):
+        device = comfy.model_management.get_torch_device()
+        encoder_path = folder_paths.get_full_path("taesd", encoder_name)
+        decoder_path = folder_paths.get_full_path("taesd", decoder_name)
+        taesd = TAESD(encoder_path, decoder_path).to(device)
+        return (taesd,)
+
 class ControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):
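
The loader resolves checkpoints through folder_paths, which presumes a "taesd" model folder is known (e.g. a models/taesd/ directory holding the two .pth files); none of the files in this commit appear to register one, so treat the lookup below as a sketch under that assumption:

import folder_paths

# Assumes a "taesd" category is registered in folder_paths.folder_names_and_paths.
print(folder_paths.get_filename_list("taesd"))                   # e.g. ["taesd_decoder.pth", "taesd_encoder.pth"]
print(folder_paths.get_full_path("taesd", "taesd_decoder.pth"))  # absolute path to the checkpoint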
@@ -931,7 +985,37 @@ class SetLatentNoiseMask:
         s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
         return (s,)

-def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+def decode_latent_to_preview_image(taesd, device, preview_format, x0):
+    x_sample = taesd.decoder(x0.to(device))[0].detach()
+    x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
+    x_sample = x_sample * 0.5
+    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+    x_sample = x_sample.astype(np.uint8)
+
+    preview_image = Image.fromarray(x_sample)
+    if preview_image.size[0] > MAX_PREVIEW_RESOLUTION or preview_image.size[1] > MAX_PREVIEW_RESOLUTION:
+        preview_image.thumbnail((MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
+
+    preview_type = 1
+    if preview_format == "JPEG":
+        preview_type = 1
+    elif preview_format == "PNG":
+        preview_type = 2
+
+    bytesIO = BytesIO()
+    header = struct.pack(">I", preview_type)
+    bytesIO.write(header)
+    preview_image.save(bytesIO, format=preview_format)
+    preview_bytes = bytesIO.getvalue()
+    return preview_bytes
+
+def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
     device = comfy.model_management.get_torch_device()
     latent_image = latent["samples"]
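
Each preview payload is therefore a 4-byte big-endian image-type tag (1 = JPEG, 2 = PNG) followed by the encoded image. A small round-trip sketch of that layout:

import struct
from io import BytesIO
from PIL import Image

buf = BytesIO()
buf.write(struct.pack(">I", 1))                      # preview_type tag: 1 = JPEG
Image.new("RGB", (64, 64)).save(buf, format="JPEG")
payload = buf.getvalue()

(image_type,) = struct.unpack(">I", payload[:4])
preview = Image.open(BytesIO(payload[4:]))           # round-trips the encoded preview
print(image_type, preview.size)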
@@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]

+    preview_format = "JPEG"
+    if preview_format not in ["JPEG", "PNG"]:
+        preview_format = "JPEG"
+
     pbar = comfy.utils.ProgressBar(steps)
     def callback(step, x0, x, total_steps):
-        pbar.update_absolute(step + 1, total_steps)
+        preview_bytes = None
+        if taesd:
+            preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
+        pbar.update_absolute(step + 1, total_steps, preview_bytes)

     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@ -970,15 +1061,18 @@ class KSampler:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ), "latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}} },
"optional": {
"taesd": ("TAESD",)
}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "sample" FUNCTION = "sample"
CATEGORY = "sampling" CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
class KSamplerAdvanced: class KSamplerAdvanced:
@classmethod @classmethod
@@ -997,21 +1091,24 @@ class KSamplerAdvanced:
                     "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                     "return_with_leftover_noise": (["disable", "enable"], ),
-                }}
+                },
+                "optional": {
+                    "taesd": ("TAESD",)
+                }}

     RETURN_TYPES = ("LATENT",)
     FUNCTION = "sample"

     CATEGORY = "sampling"

-    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
+    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
         force_full_denoise = True
         if return_with_leftover_noise == "enable":
             force_full_denoise = False
         disable_noise = False
         if add_noise == "disable":
             disable_noise = True
-        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
+        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)

 class SaveImage:
     def __init__(self):
@@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
     "VAEEncode": VAEEncode,
     "VAEEncodeForInpaint": VAEEncodeForInpaint,
     "VAELoader": VAELoader,
+    "TAESDDecode": TAESDDecode,
+    "TAESDEncode": TAESDEncode,
+    "TAESDLoader": TAESDLoader,
     "EmptyLatentImage": EmptyLatentImage,
     "LatentUpscale": LatentUpscale,
     "LatentUpscaleBy": LatentUpscaleBy,
@@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CheckpointLoader": "Load Checkpoint (With Config)",
     "CheckpointLoaderSimple": "Load Checkpoint",
     "VAELoader": "Load VAE",
+    "TAESDLoader": "Load TAESD",
     "LoraLoader": "Load LoRA",
     "CLIPLoader": "Load CLIP",
     "ControlNetLoader": "Load ControlNet Model",
@@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SetLatentNoiseMask": "Set Latent Noise Mask",
     "VAEDecode": "VAE Decode",
     "VAEEncode": "VAE Encode",
+    "TAESDDecode": "TAESD Decode",
+    "TAESDEncode": "TAESD Encode",
     "LatentRotate": "Rotate Latent",
     "LatentFlip": "Flip Latent",
     "LatentCrop": "Crop Latent",

server.py

@@ -7,6 +7,7 @@ import execution
 import uuid
 import json
 import glob
+import struct
 from PIL import Image
 from io import BytesIO
@@ -25,6 +26,11 @@ from comfy.cli_args import args
 import comfy.utils
 import comfy.model_management

+class BinaryEventTypes:
+    PREVIEW_IMAGE = 1
+

 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
@@ -457,16 +463,37 @@ class PromptServer():
         return prompt_info

     async def send(self, event, data, sid=None):
-        message = {"type": event, "data": data}
-
-        if isinstance(message, str) == False:
-            message = json.dumps(message)
+        if isinstance(data, (bytes, bytearray)):
+            await self.send_bytes(event, data, sid)
+        else:
+            await self.send_json(event, data, sid)
+
+    def encode_bytes(self, event, data):
+        if not isinstance(event, int):
+            raise RuntimeError(f"Binary event types must be integers, got {event}")
+
+        packed = struct.pack(">I", event)
+        message = bytearray(packed)
+        message.extend(data)
+        return message
+
+    async def send_bytes(self, event, data, sid=None):
+        message = self.encode_bytes(event, data)

         if sid is None:
             for ws in self.sockets.values():
-                await ws.send_str(message)
+                await ws.send_bytes(message)
         elif sid in self.sockets:
-            await self.sockets[sid].send_str(message)
+            await self.sockets[sid].send_bytes(message)
+
+    async def send_json(self, event, data, sid=None):
+        message = {"type": event, "data": data}
+
+        if sid is None:
+            for ws in self.sockets.values():
+                await ws.send_json(message)
+        elif sid in self.sockets:
+            await self.sockets[sid].send_json(message)

     def send_sync(self, event, data, sid=None):
         self.loop.call_soon_threadsafe(
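
A full binary frame on the websocket is thus the 4-byte big-endian event type from encode_bytes, followed by the payload (which for previews carries its own image-type tag, as built in nodes.py). A hedged parsing sketch (parse_binary_frame is a hypothetical helper):

import struct

def parse_binary_frame(frame: bytes):
    # Assumed layout: [4-byte BE event type][payload]; for PREVIEW_IMAGE the
    # payload is [4-byte BE image type: 1=JPEG, 2=PNG][encoded image bytes].
    (event_type,) = struct.unpack(">I", frame[:4])
    if event_type == 1:  # BinaryEventTypes.PREVIEW_IMAGE
        (image_type,) = struct.unpack(">I", frame[4:8])
        return {1: "image/jpeg", 2: "image/png"}.get(image_type, "image/jpeg"), frame[8:]
    raise ValueError(f"unknown binary event type {event_type}")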

web/extensions/core/colorPalette.js

@@ -21,6 +21,7 @@ const colorPalettes = {
         "MODEL": "#B39DDB", // light lavender-purple
         "STYLE_MODEL": "#C2FFAE", // light green-yellow
         "VAE": "#FF6E6E", // bright red
+        "TAESD": "#DCC274", // cheesecake
     },
     "litegraph_base": {
         "NODE_TITLE_COLOR": "#999",

web/scripts/api.js

@@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
         this.socket = new WebSocket(
             `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
         );
+        this.socket.binaryType = "arraybuffer";

         this.socket.addEventListener("open", () => {
             opened = true;
@@ -70,39 +71,66 @@ class ComfyApi extends EventTarget {
         this.socket.addEventListener("message", (event) => {
             try {
-                const msg = JSON.parse(event.data);
-                switch (msg.type) {
-                    case "status":
-                        if (msg.data.sid) {
-                            this.clientId = msg.data.sid;
-                            window.name = this.clientId;
-                        }
-                        this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
-                        break;
-                    case "progress":
-                        this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
-                        break;
-                    case "executing":
-                        this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
-                        break;
-                    case "executed":
-                        this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
-                        break;
-                    case "execution_start":
-                        this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
-                        break;
-                    case "execution_error":
-                        this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
-                        break;
-                    default:
-                        if (this.#registered.has(msg.type)) {
-                            this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
-                        } else {
-                            throw new Error("Unknown message type");
-                        }
-                }
+                if (event.data instanceof ArrayBuffer) {
+                    const view = new DataView(event.data);
+                    const eventType = view.getUint32(0);
+                    const buffer = event.data.slice(4);
+                    switch (eventType) {
+                        case 1:
+                            const view2 = new DataView(buffer);
+                            const imageType = view2.getUint32(0);
+                            let imageMime;
+                            switch (imageType) {
+                                case 1:
+                                default:
+                                    imageMime = "image/jpeg";
+                                    break;
+                                case 2:
+                                    imageMime = "image/png";
+                            }
+                            const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
+                            this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
+                            break;
+                        default:
+                            throw new Error(`Unknown binary websocket message of type ${eventType}`);
+                    }
+                } else {
+                    const msg = JSON.parse(event.data);
+                    switch (msg.type) {
+                        case "status":
+                            if (msg.data.sid) {
+                                this.clientId = msg.data.sid;
+                                window.name = this.clientId;
+                            }
+                            this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
+                            break;
+                        case "progress":
+                            this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
+                            break;
+                        case "executing":
+                            this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
+                            break;
+                        case "executed":
+                            this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
+                            break;
+                        case "execution_start":
+                            this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
+                            break;
+                        case "execution_error":
+                            this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
+                            break;
+                        default:
+                            if (this.#registered.has(msg.type)) {
+                                this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
+                            } else {
+                                throw new Error(`Unknown message type ${msg.type}`);
+                            }
+                    }
+                }
             } catch (error) {
-                console.warn("Unhandled message:", event.data);
+                console.warn("Unhandled message:", event.data, error);
             }
         });
     }
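
The same frames can be consumed outside the browser. A sketch of a Python watcher (aiohttp client assumed; default address 127.0.0.1:8188 as in main.py):

import asyncio
import struct

import aiohttp

async def watch_previews(url="ws://127.0.0.1:8188/ws"):
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    (event_type,) = struct.unpack(">I", msg.data[:4])
                    if event_type == 1:  # BinaryEventTypes.PREVIEW_IMAGE
                        (image_type,) = struct.unpack(">I", msg.data[4:8])
                        fmt = {1: "JPEG", 2: "PNG"}.get(image_type, "JPEG")
                        print(f"preview: {fmt}, {len(msg.data) - 8} bytes")

asyncio.run(watch_previews())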

web/scripts/app.js

@@ -44,6 +44,12 @@ export class ComfyApp {
          */
         this.nodeOutputs = {};

+        /**
+         * Stores the preview image data for each node
+         * @type {Record<string, Image>}
+         */
+        this.nodePreviewImages = {};
+
         /**
          * If the shift key on the keyboard is pressed
          * @type {boolean}
@@ -367,29 +373,52 @@ export class ComfyApp {
             node.prototype.onDrawBackground = function (ctx) {
                 if (!this.flags.collapsed) {
+                    let imgURLs = [];
+                    let imagesChanged = false;
+
                     const output = app.nodeOutputs[this.id + ""];
                     if (output && output.images) {
                         if (this.images !== output.images) {
                             this.images = output.images;
-                            this.imgs = null;
-                            this.imageIndex = null;
-                            Promise.all(
-                                output.images.map((src) => {
-                                    return new Promise((r) => {
-                                        const img = new Image();
-                                        img.onload = () => r(img);
-                                        img.onerror = () => r(null);
-                                        img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
-                                    });
-                                })
-                            ).then((imgs) => {
-                                if (this.images === output.images) {
-                                    this.imgs = imgs.filter(Boolean);
-                                    this.setSizeForImage?.();
-                                    app.graph.setDirtyCanvas(true);
-                                }
-                            });
+                            imagesChanged = true;
+                            imgURLs = imgURLs.concat(output.images.map((params) => {
+                                return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
+                            }));
                         }
                     }
+
+                    const preview = app.nodePreviewImages[this.id + ""];
+                    if (this.preview !== preview) {
+                        this.preview = preview;
+                        imagesChanged = true;
+                        if (preview != null) {
+                            imgURLs.push(preview);
+                        }
+                    }
+
+                    if (imagesChanged) {
+                        this.imageIndex = null;
+                        if (imgURLs.length > 0) {
+                            Promise.all(
+                                imgURLs.map((src) => {
+                                    return new Promise((r) => {
+                                        const img = new Image();
+                                        img.onload = () => r(img);
+                                        img.onerror = () => r(null);
+                                        img.src = src;
+                                    });
+                                })
+                            ).then((imgs) => {
+                                if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
+                                    this.imgs = imgs.filter(Boolean);
+                                    this.setSizeForImage?.();
+                                    app.graph.setDirtyCanvas(true);
+                                }
+                            });
+                        } else {
+                            this.imgs = null;
+                        }
+                    }
                 }

                 if (this.imgs && this.imgs.length) {
@@ -901,17 +930,20 @@ export class ComfyApp {
             this.progress = null;
             this.runningNodeId = detail;
             this.graph.setDirtyCanvas(true, false);
+            delete this.nodePreviewImages[this.runningNodeId];
         });

         api.addEventListener("executed", ({ detail }) => {
             this.nodeOutputs[detail.node] = detail.output;
             const node = this.graph.getNodeById(detail.node);
-            if (node?.onExecuted) {
-                node.onExecuted(detail.output);
+            if (node) {
+                if (node.onExecuted)
+                    node.onExecuted(detail.output);
             }
         });

         api.addEventListener("execution_start", ({ detail }) => {
+            this.runningNodeId = null;
             this.lastExecutionError = null
         });
@@ -922,6 +954,16 @@ export class ComfyApp {
             this.canvas.draw(true, true);
         });

+        api.addEventListener("b_preview", ({ detail }) => {
+            const id = this.runningNodeId;
+            if (id == null)
+                return;
+
+            const blob = detail;
+            const blobUrl = URL.createObjectURL(blob);
+            this.nodePreviewImages[id] = [blobUrl];
+        });
+
         api.init();
     }
@@ -1465,8 +1507,10 @@ export class ComfyApp {
      */
     clean() {
         this.nodeOutputs = {};
+        this.nodePreviewImages = {};
         this.lastPromptError = null;
         this.lastExecutionError = null;
+        this.runningNodeId = null;
     }
 }