Preview sampled images with TAESD
This commit is contained in:
parent
2ec980bb9f
commit
b4f434ee66
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tiny AutoEncoder for Stable Diffusion
|
||||
(DNN for encoding / decoding SD's latent space)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
def forward(self, x):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
def Encoder():
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
)
|
||||
|
||||
def Decoder():
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
"""raw latents -> [0, 1]"""
|
||||
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
|
@ -197,14 +197,14 @@ class ProgressBar:
|
|||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
|
||||
def update_absolute(self, value, total=None):
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total)
|
||||
self.hook(self.current, self.total, preview)
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
|
6
main.py
6
main.py
|
@ -26,6 +26,7 @@ import yaml
|
|||
import execution
|
||||
import folder_paths
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from nodes import init_custom_nodes
|
||||
|
||||
|
||||
|
@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
|||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
||||
def hijack_progress(server):
|
||||
def hook(value, total):
|
||||
def hook(value, total, preview_bytes_jpeg):
|
||||
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
|
||||
if preview_bytes_jpeg is not None:
|
||||
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes_jpeg, server.client_id)
|
||||
pass
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
def cleanup_temp():
|
||||
|
|
119
nodes.py
119
nodes.py
|
@ -7,6 +7,8 @@ import hashlib
|
|||
import traceback
|
||||
import math
|
||||
import time
|
||||
import struct
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
@ -22,6 +24,7 @@ import comfy.samplers
|
|||
import comfy.sample
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
from comfy.taesd.taesd import TAESD
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
|
@ -38,6 +41,7 @@ def interrupt_processing(value=True):
|
|||
comfy.model_management.interrupt_current_processing(value)
|
||||
|
||||
MAX_RESOLUTION=8192
|
||||
MAX_PREVIEW_RESOLUTION = 512
|
||||
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
|
@ -171,6 +175,21 @@ class VAEDecodeTiled:
|
|||
def decode(self, vae, samples):
|
||||
return (vae.decode_tiled(samples["samples"]), )
|
||||
|
||||
class TAESDDecode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "taesd": ("TAESD", )}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def decode(self, taesd, samples):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
# [B, C, H, W] -> [B, H, W, C]
|
||||
pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1)
|
||||
return (pixels, )
|
||||
|
||||
class VAEEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
|
|||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
class TAESDEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "taesd": ("TAESD", )}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def encode(self, taesd, pixels):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
# [B, H, W, C] -> [B, C, H, W]
|
||||
samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
|
||||
return ({"samples": samples}, )
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
|
@ -464,6 +498,26 @@ class VAELoader:
|
|||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
return (vae,)
|
||||
|
||||
class TAESDLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
model_list = folder_paths.get_filename_list("taesd")
|
||||
return {"required": {
|
||||
"encoder_name": (model_list, { "default": "taesd_encoder.pth" }),
|
||||
"decoder_name": (model_list, { "default": "taesd_decoder.pth" })
|
||||
}}
|
||||
RETURN_TYPES = ("TAESD",)
|
||||
FUNCTION = "load_taesd"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_taesd(self, encoder_name, decoder_name):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
encoder_path = folder_paths.get_full_path("taesd", encoder_name)
|
||||
decoder_path = folder_paths.get_full_path("taesd", decoder_name)
|
||||
taesd = TAESD(encoder_path, decoder_path).to(device)
|
||||
return (taesd,)
|
||||
|
||||
class ControlNetLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
@ -931,7 +985,37 @@ class SetLatentNoiseMask:
|
|||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
return (s,)
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
|
||||
def decode_latent_to_preview_image(taesd, device, preview_format, x0):
|
||||
x_sample = taesd.decoder(x0.to(device))[0].detach()
|
||||
x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
|
||||
x_sample = x_sample * 0.5
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
preview_image = Image.fromarray(x_sample)
|
||||
|
||||
if preview_image.size[0] > MAX_PREVIEW_RESOLUTION or preview_image.size[1] > MAX_PREVIEW_RESOLUTION:
|
||||
preview_image.thumbnail((MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
|
||||
|
||||
preview_type = 1
|
||||
if preview_format == "JPEG":
|
||||
preview_type = 1
|
||||
elif preview_format == "PNG":
|
||||
preview_type = 2
|
||||
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", preview_type)
|
||||
bytesIO.write(header)
|
||||
preview_image.save(bytesIO, format=preview_format)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
|
||||
return preview_bytes
|
||||
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
latent_image = latent["samples"]
|
||||
|
||||
|
@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
preview_format = "JPEG"
|
||||
if preview_format not in ["JPEG", "PNG"]:
|
||||
preview_format = "JPEG"
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
def callback(step, x0, x, total_steps):
|
||||
pbar.update_absolute(step + 1, total_steps)
|
||||
preview_bytes = None
|
||||
if taesd:
|
||||
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
|
@ -970,15 +1061,18 @@ class KSampler:
|
|||
"negative": ("CONDITIONING", ),
|
||||
"latent_image": ("LATENT", ),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
},
|
||||
"optional": {
|
||||
"taesd": ("TAESD",)
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "sampling"
|
||||
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
|
||||
|
||||
class KSamplerAdvanced:
|
||||
@classmethod
|
||||
|
@ -997,21 +1091,24 @@ class KSamplerAdvanced:
|
|||
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
||||
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
||||
"return_with_leftover_noise": (["disable", "enable"], ),
|
||||
}}
|
||||
},
|
||||
"optional": {
|
||||
"taesd": ("TAESD",)
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "sampling"
|
||||
|
||||
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
|
||||
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
|
||||
force_full_denoise = True
|
||||
if return_with_leftover_noise == "enable":
|
||||
force_full_denoise = False
|
||||
disable_noise = False
|
||||
if add_noise == "disable":
|
||||
disable_noise = True
|
||||
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
|
||||
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
|
||||
|
||||
class SaveImage:
|
||||
def __init__(self):
|
||||
|
@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
|
|||
"VAEEncode": VAEEncode,
|
||||
"VAEEncodeForInpaint": VAEEncodeForInpaint,
|
||||
"VAELoader": VAELoader,
|
||||
"TAESDDecode": TAESDDecode,
|
||||
"TAESDEncode": TAESDEncode,
|
||||
"TAESDLoader": TAESDLoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentUpscaleBy": LatentUpscaleBy,
|
||||
|
@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"CheckpointLoader": "Load Checkpoint (With Config)",
|
||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"VAELoader": "Load VAE",
|
||||
"TAESDLoader": "Load TAESD",
|
||||
"LoraLoader": "Load LoRA",
|
||||
"CLIPLoader": "Load CLIP",
|
||||
"ControlNetLoader": "Load ControlNet Model",
|
||||
|
@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||
"VAEDecode": "VAE Decode",
|
||||
"VAEEncode": "VAE Encode",
|
||||
"TAESDDecode": "TAESD Decode",
|
||||
"TAESDEncode": "TAESD Encode",
|
||||
"LatentRotate": "Rotate Latent",
|
||||
"LatentFlip": "Flip Latent",
|
||||
"LatentCrop": "Crop Latent",
|
||||
|
|
39
server.py
39
server.py
|
@ -7,6 +7,7 @@ import execution
|
|||
import uuid
|
||||
import json
|
||||
import glob
|
||||
import struct
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
|
@ -25,6 +26,11 @@ from comfy.cli_args import args
|
|||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
|
@ -457,16 +463,37 @@ class PromptServer():
|
|||
return prompt_info
|
||||
|
||||
async def send(self, event, data, sid=None):
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if isinstance(message, str) == False:
|
||||
message = json.dumps(message)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
await self.send_json(event, data, sid)
|
||||
|
||||
def encode_bytes(self, event, data):
|
||||
if not isinstance(event, int):
|
||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||
|
||||
packed = struct.pack(">I", event)
|
||||
message = bytearray(packed)
|
||||
message.extend(data)
|
||||
return message
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
await ws.send_str(message)
|
||||
await ws.send_bytes(message)
|
||||
elif sid in self.sockets:
|
||||
await self.sockets[sid].send_str(message)
|
||||
await self.sockets[sid].send_bytes(message)
|
||||
|
||||
async def send_json(self, event, data, sid=None):
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
await ws.send_json(message)
|
||||
elif sid in self.sockets:
|
||||
await self.sockets[sid].send_json(message)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
self.loop.call_soon_threadsafe(
|
||||
|
|
|
@ -21,6 +21,7 @@ const colorPalettes = {
|
|||
"MODEL": "#B39DDB", // light lavender-purple
|
||||
"STYLE_MODEL": "#C2FFAE", // light green-yellow
|
||||
"VAE": "#FF6E6E", // bright red
|
||||
"TAESD": "#DCC274", // cheesecake
|
||||
},
|
||||
"litegraph_base": {
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
|
|
|
@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
|
|||
this.socket = new WebSocket(
|
||||
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
|
||||
);
|
||||
this.socket.binaryType = "arraybuffer";
|
||||
|
||||
this.socket.addEventListener("open", () => {
|
||||
opened = true;
|
||||
|
@ -70,39 +71,66 @@ class ComfyApi extends EventTarget {
|
|||
|
||||
this.socket.addEventListener("message", (event) => {
|
||||
try {
|
||||
const msg = JSON.parse(event.data);
|
||||
switch (msg.type) {
|
||||
case "status":
|
||||
if (msg.data.sid) {
|
||||
this.clientId = msg.data.sid;
|
||||
window.name = this.clientId;
|
||||
if (event.data instanceof ArrayBuffer) {
|
||||
const view = new DataView(event.data);
|
||||
const eventType = view.getUint32(0);
|
||||
const buffer = event.data.slice(4);
|
||||
console.error("BINARY", eventType);
|
||||
switch (eventType) {
|
||||
case 1:
|
||||
const view2 = new DataView(event.data);
|
||||
const imageType = view2.getUint32(0)
|
||||
let imageMime
|
||||
switch (imageType) {
|
||||
case 1:
|
||||
default:
|
||||
imageMime = "image/jpeg";
|
||||
break;
|
||||
case 2:
|
||||
imageMime = "image/png"
|
||||
}
|
||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||
break;
|
||||
case "progress":
|
||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||
break;
|
||||
case "executing":
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||
break;
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_start":
|
||||
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_error":
|
||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||
const jpegBlob = new Blob([buffer.slice(4)], { type: imageMime });
|
||||
this.dispatchEvent(new CustomEvent("b_preview", { detail: jpegBlob }));
|
||||
break;
|
||||
default:
|
||||
if (this.#registered.has(msg.type)) {
|
||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||
} else {
|
||||
throw new Error("Unknown message type");
|
||||
}
|
||||
throw new Error(`Unknown binary websocket message of type ${eventType}`);
|
||||
}
|
||||
}
|
||||
else {
|
||||
const msg = JSON.parse(event.data);
|
||||
switch (msg.type) {
|
||||
case "status":
|
||||
if (msg.data.sid) {
|
||||
this.clientId = msg.data.sid;
|
||||
window.name = this.clientId;
|
||||
}
|
||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||
break;
|
||||
case "progress":
|
||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||
break;
|
||||
case "executing":
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||
break;
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_start":
|
||||
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_error":
|
||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||
break;
|
||||
default:
|
||||
if (this.#registered.has(msg.type)) {
|
||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||
} else {
|
||||
throw new Error(`Unknown message type ${msg.type}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("Unhandled message:", event.data);
|
||||
console.warn("Unhandled message:", event.data, error);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -44,6 +44,12 @@ export class ComfyApp {
|
|||
*/
|
||||
this.nodeOutputs = {};
|
||||
|
||||
/**
|
||||
* Stores the preview image data for each node
|
||||
* @type {Record<string, Image>}
|
||||
*/
|
||||
this.nodePreviewImages = {};
|
||||
|
||||
/**
|
||||
* If the shift key on the keyboard is pressed
|
||||
* @type {boolean}
|
||||
|
@ -367,29 +373,52 @@ export class ComfyApp {
|
|||
|
||||
node.prototype.onDrawBackground = function (ctx) {
|
||||
if (!this.flags.collapsed) {
|
||||
let imgURLs = []
|
||||
let imagesChanged = false
|
||||
|
||||
const output = app.nodeOutputs[this.id + ""];
|
||||
if (output && output.images) {
|
||||
if (this.images !== output.images) {
|
||||
this.images = output.images;
|
||||
this.imgs = null;
|
||||
this.imageIndex = null;
|
||||
imagesChanged = true;
|
||||
imgURLs = imgURLs.concat(output.images.map(params => {
|
||||
return "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
const preview = app.nodePreviewImages[this.id + ""]
|
||||
if (this.preview !== preview) {
|
||||
this.preview = preview
|
||||
imagesChanged = true;
|
||||
if (preview != null) {
|
||||
imgURLs.push(preview);
|
||||
}
|
||||
}
|
||||
|
||||
if (imagesChanged) {
|
||||
this.imageIndex = null;
|
||||
if (imgURLs.length > 0) {
|
||||
Promise.all(
|
||||
output.images.map((src) => {
|
||||
imgURLs.map((src) => {
|
||||
return new Promise((r) => {
|
||||
const img = new Image();
|
||||
img.onload = () => r(img);
|
||||
img.onerror = () => r(null);
|
||||
img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
|
||||
img.src = src
|
||||
});
|
||||
})
|
||||
).then((imgs) => {
|
||||
if (this.images === output.images) {
|
||||
if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
|
||||
this.imgs = imgs.filter(Boolean);
|
||||
this.setSizeForImage?.();
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
});
|
||||
}
|
||||
else {
|
||||
this.imgs = null;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.imgs && this.imgs.length) {
|
||||
|
@ -901,17 +930,20 @@ export class ComfyApp {
|
|||
this.progress = null;
|
||||
this.runningNodeId = detail;
|
||||
this.graph.setDirtyCanvas(true, false);
|
||||
delete this.nodePreviewImages[this.runningNodeId]
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node?.onExecuted) {
|
||||
node.onExecuted(detail.output);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
}
|
||||
});
|
||||
|
||||
api.addEventListener("execution_start", ({ detail }) => {
|
||||
this.runningNodeId = null;
|
||||
this.lastExecutionError = null
|
||||
});
|
||||
|
||||
|
@ -922,6 +954,16 @@ export class ComfyApp {
|
|||
this.canvas.draw(true, true);
|
||||
});
|
||||
|
||||
api.addEventListener("b_preview", ({ detail }) => {
|
||||
const id = this.runningNodeId
|
||||
if (id == null)
|
||||
return;
|
||||
|
||||
const blob = detail
|
||||
const blobUrl = URL.createObjectURL(blob)
|
||||
this.nodePreviewImages[id] = [blobUrl]
|
||||
});
|
||||
|
||||
api.init();
|
||||
}
|
||||
|
||||
|
@ -1465,8 +1507,10 @@ export class ComfyApp {
|
|||
*/
|
||||
clean() {
|
||||
this.nodeOutputs = {};
|
||||
this.nodePreviewImages = {}
|
||||
this.lastPromptError = null;
|
||||
this.lastExecutionError = null;
|
||||
this.runningNodeId = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue