comfyanonymous 2023-06-05 23:53:36 -04:00
commit 081134f5c8
12 changed files with 359 additions and 47 deletions

README.md
View File

@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@ -181,6 +182,10 @@ You can set this command line setting to disable the upcasting to fp32 in some c
```--dont-upcast-attention```
## How to show high-quality previews?
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
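As a rough sketch (not part of the official instructions), the two checkpoints listed above can be fetched with a few lines of Python run from the ComfyUI directory; adjust the destination path to your install:

```python
# Sketch: download the TAESD checkpoints into models/taesd (paths assumed).
import os
import urllib.request

urls = [
    "https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth",
    "https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth",
]
os.makedirs("models/taesd", exist_ok=True)
for url in urls:
    urllib.request.urlretrieve(url, os.path.join("models/taesd", os.path.basename(url)))
```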
## Support and dev channel
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).

comfy/cli_args.py
View File

@ -1,4 +1,35 @@
import argparse
import enum
class EnumAction(argparse.Action):
    """
    Argparse action for handling Enums
    """
    def __init__(self, **kwargs):
        # Pop off the type value
        enum_type = kwargs.pop("type", None)

        # Ensure an Enum subclass is provided
        if enum_type is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
        if not issubclass(enum_type, enum.Enum):
            raise TypeError("type must be an Enum when using EnumAction")

        # Generate choices from the Enum
        choices = tuple(e.value for e in enum_type)
        kwargs.setdefault("choices", choices)
        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")

        super(EnumAction, self).__init__(**kwargs)
        self._enum = enum_type

    def __call__(self, parser, namespace, values, option_string=None):
        # Convert value back into an Enum
        value = self._enum(values)
        setattr(namespace, self.dest, value)
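For orientation, a minimal sketch of how EnumAction could be attached to a flag; the enum and flag name here are illustrative, and this commit itself registers its preview option with `type=str` further down:

```python
# Illustrative usage of the EnumAction defined above (not how the commit wires it up).
import argparse
import enum

class PreviewMethod(enum.Enum):  # hypothetical enum for this example
    Auto = "auto"
    TAESD = "taesd"

p = argparse.ArgumentParser()
p.add_argument("--preview", type=PreviewMethod, action=EnumAction, default=PreviewMethod.Auto)
ns = p.parse_args(["--preview", "taesd"])
assert ns.preview is PreviewMethod.TAESD  # the string value is converted back to the Enum
```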
parser = argparse.ArgumentParser()
@ -13,6 +44,13 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class LatentPreviewMethod(enum.Enum):
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

65
comfy/taesd/taesd.py Normal file
View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn
def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 4),
    )

def Decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )

class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
        if decoder_path is not None:
            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
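A minimal usage sketch for the class above, assuming the two checkpoints have already been downloaded next to the script (see the README section earlier in this commit); the decoder output range is approximately [0, 1]:

```python
# Sketch: decode a 4-channel SD latent to an RGB image with TAESD.
# Assumes taesd_encoder.pth / taesd_decoder.pth are present in the working directory.
import torch
from PIL import Image

taesd = TAESD("taesd_encoder.pth", "taesd_decoder.pth").eval()

with torch.no_grad():
    latent = torch.randn(1, 4, 64, 64)               # stand-in for a sampler's current x0
    rgb = taesd.decoder(latent)[0].clamp(0, 1)       # (3, 512, 512), roughly in [0, 1]
    img = Image.fromarray((rgb.permute(1, 2, 0).numpy() * 255).astype("uint8"))
    img.save("preview.png")
```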

comfy/utils.py
View File

@ -1,6 +1,7 @@
import torch
import math
import struct
import comfy.model_management
def load_torch_file(ckpt, safe_load=False):
    if ckpt.lower().endswith(".safetensors"):
@ -166,6 +167,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                comfy.model_management.throw_exception_if_processing_interrupted()
                s_in = s[:,:,y:y+tile_y,x:x+tile_x]
                ps = function(s_in).cpu()
@ -197,14 +200,14 @@ class ProgressBar:
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview)

    def update(self, value):
        self.update_absolute(self.current + value)

folder_paths.py
View File

@ -18,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)

main.py
View File

@ -26,6 +26,7 @@ import yaml
import execution
import folder_paths
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
@ -40,8 +41,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

def hijack_progress(server):
    def hook(value, total, preview_image_bytes):
        server.send_sync("progress", { "value": value, "max": total}, server.client_id)
        if preview_image_bytes is not None:
            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
    comfy.utils.set_progress_bar_global_hook(hook)

def cleanup_temp():

104
nodes.py
View File

@ -7,6 +7,8 @@ import hashlib
import traceback
import math
import time
import struct
from io import BytesIO
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
@ -22,6 +24,8 @@ import comfy.samplers
import comfy.sample
import comfy.sd
import comfy.utils
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.clip_vision
@ -31,6 +35,32 @@ import importlib
import folder_paths
class LatentPreviewer:
    def decode_latent_to_preview(self, device, x0):
        pass

class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self):
        self.latent_rgb_factors = torch.tensor([
            #   R        G        B
            [ 0.298,  0.207,  0.208],  # L1
            [ 0.187,  0.286,  0.173],  # L2
            [-0.158,  0.189,  0.264],  # L3
            [-0.184, -0.271, -0.473],  # L4
        ], device="cpu")

    def decode_latent_to_preview(self, device, x0):
        latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors

        latents_ubyte = (((latent_image + 1) / 2)
                            .clamp(0, 1)  # change scale from -1..1 to 0..1
                            .mul(0xFF)    # to 0..255
                            .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())
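As a quick illustration (not part of the commit), the linear latent-to-RGB projection above can be exercised on a random latent; it yields one RGB pixel per latent position, assuming the class and the `PIL` import above are in scope:

```python
# Sketch: preview a random 4-channel latent via the Latent2RGB projection above.
import torch

previewer = Latent2RGBPreviewer()
x0 = torch.randn(1, 4, 64, 64)                 # stand-in for the sampler's current x0
img = previewer.decode_latent_to_preview("cpu", x0)
print(img.size)                                # (64, 64): one pixel per latent position
```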
def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()
@ -38,6 +68,7 @@ def interrupt_processing(value=True):
    comfy.model_management.interrupt_current_processing(value)

MAX_RESOLUTION=8192
MAX_PREVIEW_RESOLUTION = 512
class CLIPTextEncode:
    @classmethod
@ -248,6 +279,21 @@ class VAEEncodeForInpaint:
        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class TAESDPreviewerImpl(LatentPreviewer):
    def __init__(self, taesd):
        self.taesd = taesd

    def decode_latent_to_preview(self, device, x0):
        x_sample = self.taesd.decoder(x0.to(device))[0].detach()
        # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5)  # returns value in [-2, 2]
        x_sample = x_sample.sub(0.5).mul(2)

        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)

        preview_image = Image.fromarray(x_sample)
        return preview_image
class SaveLatent:
    def __init__(self):
@ -931,6 +977,26 @@ class SetLatentNoiseMask:
        s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (s,)
def decode_latent_to_preview_image(previewer, device, preview_format, x0):
    preview_image = previewer.decode_latent_to_preview(device, x0)
    preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)

    preview_type = 1
    if preview_format == "JPEG":
        preview_type = 1
    elif preview_format == "PNG":
        preview_type = 2

    bytesIO = BytesIO()
    header = struct.pack(">I", preview_type)
    bytesIO.write(header)
    preview_image.save(bytesIO, format=preview_format)
    preview_bytes = bytesIO.getvalue()
    return preview_bytes
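For reference, here is a small sketch of how a client could unpack the bytes this helper produces once they arrive framed by the server (see `encode_bytes` in server.py further down); the function name is illustrative and not part of the codebase:

```python
# Sketch: decoding a binary preview frame as a client would, assuming the framing
# shown in this commit (4-byte event type, then 4-byte image type, then image data).
import struct
from io import BytesIO
from PIL import Image

def parse_binary_preview(frame: bytes):
    event_type = struct.unpack(">I", frame[:4])[0]      # 1 == PREVIEW_IMAGE
    image_type = struct.unpack(">I", frame[4:8])[0]     # 1 == JPEG, 2 == PNG
    image = Image.open(BytesIO(frame[8:]))
    return event_type, image_type, image
```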
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    device = comfy.model_management.get_torch_device()
    latent_image = latent["samples"]
@ -945,9 +1011,39 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = None
    if not args.disable_previews:
        # TODO previewer methods
        taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth")
        taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth")

        method = args.default_preview_method

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB
            if taesd_encoder_path and taesd_decoder_path:
                method = LatentPreviewMethod.TAESD

        if method == LatentPreviewMethod.TAESD:
            if taesd_encoder_path and taesd_decoder_path:
                taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device)
                previewer = TAESDPreviewerImpl(taesd)
            else:
                print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth")

        if previewer is None:
            previewer = Latent2RGBPreviewer()

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        preview_bytes = None
        if previewer:
            preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@ -970,7 +1066,8 @@ class KSampler:
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
@ -997,7 +1094,8 @@ class KSamplerAdvanced:
                     "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                     "return_with_leftover_noise": (["disable", "enable"], ),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

server.py
View File

@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
import struct
from PIL import Image
from io import BytesIO
@ -25,6 +26,11 @@ from comfy.cli_args import args
import comfy.utils
import comfy.model_management
class BinaryEventTypes:
    PREVIEW_IMAGE = 1
@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
@ -457,16 +463,37 @@ class PromptServer():
        return prompt_info

    async def send(self, event, data, sid=None):
        if isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        if not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
        message = bytearray(packed)
        message.extend(data)
        return message

    async def send_bytes(self, event, data, sid=None):
        message = self.encode_bytes(event, data)

        if sid is None:
            for ws in self.sockets.values():
                await ws.send_bytes(message)
        elif sid in self.sockets:
            await self.sockets[sid].send_bytes(message)

    async def send_json(self, event, data, sid=None):
        message = {"type": event, "data": data}

        if sid is None:
            for ws in self.sockets.values():
                await ws.send_json(message)
        elif sid in self.sockets:
            await self.sockets[sid].send_json(message)
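For orientation only, the wire layout produced by the helpers above (the image payload built in nodes.py plus the frame added by `encode_bytes`) can be sketched as follows; the byte string is a placeholder, not real image data:

```python
# Sketch of the resulting binary frame, assuming the framing added in this commit:
# [event_type: 4-byte big-endian][image_type: 4-byte big-endian][JPEG/PNG bytes]
import struct

payload = struct.pack(">I", 1) + b"<jpeg bytes>"   # image_type 1 == JPEG (placeholder data)
frame = struct.pack(">I", 1) + payload             # event_type 1 == BinaryEventTypes.PREVIEW_IMAGE
assert struct.unpack(">I", frame[:4])[0] == 1
```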
    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(

web/extensions/core/colorPalette.js
View File

@ -21,6 +21,7 @@ const colorPalettes = {
        "MODEL": "#B39DDB", // light lavender-purple
        "STYLE_MODEL": "#C2FFAE", // light green-yellow
        "VAE": "#FF6E6E", // bright red
        "TAESD": "#DCC274", // cheesecake
    },
    "litegraph_base": {
        "NODE_TITLE_COLOR": "#999",

web/scripts/api.js
View File

@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
        this.socket = new WebSocket(
            `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
        );
        this.socket.binaryType = "arraybuffer";

        this.socket.addEventListener("open", () => {
            opened = true;
@ -70,39 +71,65 @@ class ComfyApi extends EventTarget {
        this.socket.addEventListener("message", (event) => {
            try {
                if (event.data instanceof ArrayBuffer) {
                    const view = new DataView(event.data);
                    const eventType = view.getUint32(0);
                    const buffer = event.data.slice(4);
                    switch (eventType) {
                        case 1:
                            const view2 = new DataView(event.data);
                            const imageType = view2.getUint32(0)
                            let imageMime
                            switch (imageType) {
                                case 1:
                                default:
                                    imageMime = "image/jpeg";
                                    break;
                                case 2:
                                    imageMime = "image/png"
                            }
                            const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
                            this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
                            break;
                        default:
                            throw new Error(`Unknown binary websocket message of type ${eventType}`);
                    }
                }
                else {
                    const msg = JSON.parse(event.data);
                    switch (msg.type) {
                        case "status":
                            if (msg.data.sid) {
                                this.clientId = msg.data.sid;
                                window.name = this.clientId;
                            }
                            this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
                            break;
                        case "progress":
                            this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
                            break;
                        case "executing":
                            this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
                            break;
                        case "executed":
                            this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
                            break;
                        case "execution_start":
                            this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
                            break;
                        case "execution_error":
                            this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
                            break;
                        default:
                            if (this.#registered.has(msg.type)) {
                                this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
                            } else {
                                throw new Error(`Unknown message type ${msg.type}`);
                            }
                    }
                }
            } catch (error) {
                console.warn("Unhandled message:", event.data, error);
            }
        });
    }

web/scripts/app.js
View File

@ -44,6 +44,12 @@ export class ComfyApp {
         */
        this.nodeOutputs = {};

        /**
         * Stores the preview image data for each node
         * @type {Record<string, Image>}
         */
        this.nodePreviewImages = {};

        /**
         * If the shift key on the keyboard is pressed
         * @type {boolean}
         */
@ -367,29 +373,52 @@ export class ComfyApp {
        node.prototype.onDrawBackground = function (ctx) {
            if (!this.flags.collapsed) {
                let imgURLs = []
                let imagesChanged = false

                const output = app.nodeOutputs[this.id + ""];
                if (output && output.images) {
                    if (this.images !== output.images) {
                        this.images = output.images;
                        imagesChanged = true;
                        imgURLs = imgURLs.concat(output.images.map(params => {
                            return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
                        }))
                    }
                }

                const preview = app.nodePreviewImages[this.id + ""]
                if (this.preview !== preview) {
                    this.preview = preview
                    imagesChanged = true;
                    if (preview != null) {
                        imgURLs.push(preview);
                    }
                }

                if (imagesChanged) {
                    this.imageIndex = null;
                    if (imgURLs.length > 0) {
                        Promise.all(
                            imgURLs.map((src) => {
                                return new Promise((r) => {
                                    const img = new Image();
                                    img.onload = () => r(img);
                                    img.onerror = () => r(null);
                                    img.src = src
                                });
                            })
                        ).then((imgs) => {
                            if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
                                this.imgs = imgs.filter(Boolean);
                                this.setSizeForImage?.();
                                app.graph.setDirtyCanvas(true);
                            }
                        });
                    }
                    else {
                        this.imgs = null;
                    }
                }

                if (this.imgs && this.imgs.length) {
@ -901,17 +930,20 @@ export class ComfyApp {
            this.progress = null;
            this.runningNodeId = detail;
            this.graph.setDirtyCanvas(true, false);
            delete this.nodePreviewImages[this.runningNodeId]
        });

        api.addEventListener("executed", ({ detail }) => {
            this.nodeOutputs[detail.node] = detail.output;
            const node = this.graph.getNodeById(detail.node);
            if (node) {
                if (node.onExecuted)
                    node.onExecuted(detail.output);
            }
        });

        api.addEventListener("execution_start", ({ detail }) => {
            this.runningNodeId = null;
            this.lastExecutionError = null
        });
@ -922,6 +954,16 @@ export class ComfyApp {
            this.canvas.draw(true, true);
        });

        api.addEventListener("b_preview", ({ detail }) => {
            const id = this.runningNodeId
            if (id == null)
                return;

            const blob = detail
            const blobUrl = URL.createObjectURL(blob)
            this.nodePreviewImages[id] = [blobUrl]
        });

        api.init();
    }
@ -1465,8 +1507,10 @@ export class ComfyApp {
     */
    clean() {
        this.nodeOutputs = {};
        this.nodePreviewImages = {}
        this.lastPromptError = null;
        this.lastExecutionError = null;
        this.runningNodeId = null;
    }
}