Merge branch 'taesd-preview' of https://github.com/space-nuko/ComfyUI
This commit is contained in:
commit
081134f5c8
|
@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||||
|
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
|
||||||
- Starts up very fast.
|
- Starts up very fast.
|
||||||
- Works fully offline: will never download anything.
|
- Works fully offline: will never download anything.
|
||||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||||
|
@ -181,6 +182,10 @@ You can set this command line setting to disable the upcasting to fp32 in some c
|
||||||
|
|
||||||
```--dont-upcast-attention```
|
```--dont-upcast-attention```
|
||||||
|
|
||||||
|
## How to show high-quality previews?
|
||||||
|
|
||||||
|
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||||
|
|
||||||
## Support and dev channel
|
## Support and dev channel
|
||||||
|
|
||||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||||
|
|
|
@ -1,4 +1,35 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class EnumAction(argparse.Action):
|
||||||
|
"""
|
||||||
|
Argparse action for handling Enums
|
||||||
|
"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
# Pop off the type value
|
||||||
|
enum_type = kwargs.pop("type", None)
|
||||||
|
|
||||||
|
# Ensure an Enum subclass is provided
|
||||||
|
if enum_type is None:
|
||||||
|
raise ValueError("type must be assigned an Enum when using EnumAction")
|
||||||
|
if not issubclass(enum_type, enum.Enum):
|
||||||
|
raise TypeError("type must be an Enum when using EnumAction")
|
||||||
|
|
||||||
|
# Generate choices from the Enum
|
||||||
|
choices = tuple(e.value for e in enum_type)
|
||||||
|
kwargs.setdefault("choices", choices)
|
||||||
|
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
|
||||||
|
|
||||||
|
super(EnumAction, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
self._enum = enum_type
|
||||||
|
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
|
# Convert value back into an Enum
|
||||||
|
value = self._enum(values)
|
||||||
|
setattr(namespace, self.dest, value)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
@ -13,6 +44,13 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
||||||
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|
||||||
|
class LatentPreviewMethod(enum.Enum):
|
||||||
|
Auto = "auto"
|
||||||
|
Latent2RGB = "latent2rgb"
|
||||||
|
TAESD = "taesd"
|
||||||
|
parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.")
|
||||||
|
parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tiny AutoEncoder for Stable Diffusion
|
||||||
|
(DNN for encoding / decoding SD's latent space)
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def conv(n_in, n_out, **kwargs):
|
||||||
|
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||||
|
|
||||||
|
class Clamp(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.tanh(x / 3) * 3
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, n_in, n_out):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||||
|
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||||
|
self.fuse = nn.ReLU()
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
|
||||||
|
def Encoder():
|
||||||
|
return nn.Sequential(
|
||||||
|
conv(3, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 4),
|
||||||
|
)
|
||||||
|
|
||||||
|
def Decoder():
|
||||||
|
return nn.Sequential(
|
||||||
|
Clamp(), conv(4, 64), nn.ReLU(),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), conv(64, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
class TAESD(nn.Module):
|
||||||
|
latent_magnitude = 3
|
||||||
|
latent_shift = 0.5
|
||||||
|
|
||||||
|
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
|
||||||
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder()
|
||||||
|
self.decoder = Decoder()
|
||||||
|
if encoder_path is not None:
|
||||||
|
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
|
||||||
|
if decoder_path is not None:
|
||||||
|
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scale_latents(x):
|
||||||
|
"""raw latents -> [0, 1]"""
|
||||||
|
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def unscale_latents(x):
|
||||||
|
"""[0, 1] -> raw latents"""
|
||||||
|
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
|
@ -166,6 +167,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
||||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
||||||
for y in range(0, s.shape[2], tile_y - overlap):
|
for y in range(0, s.shape[2], tile_y - overlap):
|
||||||
for x in range(0, s.shape[3], tile_x - overlap):
|
for x in range(0, s.shape[3], tile_x - overlap):
|
||||||
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
||||||
|
|
||||||
ps = function(s_in).cpu()
|
ps = function(s_in).cpu()
|
||||||
|
@ -197,14 +200,14 @@ class ProgressBar:
|
||||||
self.current = 0
|
self.current = 0
|
||||||
self.hook = PROGRESS_BAR_HOOK
|
self.hook = PROGRESS_BAR_HOOK
|
||||||
|
|
||||||
def update_absolute(self, value, total=None):
|
def update_absolute(self, value, total=None, preview=None):
|
||||||
if total is not None:
|
if total is not None:
|
||||||
self.total = total
|
self.total = total
|
||||||
if value > self.total:
|
if value > self.total:
|
||||||
value = self.total
|
value = self.total
|
||||||
self.current = value
|
self.current = value
|
||||||
if self.hook is not None:
|
if self.hook is not None:
|
||||||
self.hook(self.current, self.total)
|
self.hook(self.current, self.total, preview)
|
||||||
|
|
||||||
def update(self, value):
|
def update(self, value):
|
||||||
self.update_absolute(self.current + value)
|
self.update_absolute(self.current + value)
|
||||||
|
|
|
@ -18,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
|
||||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||||
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
||||||
|
folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions)
|
||||||
|
|
||||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||||
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
|
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
|
||||||
|
|
5
main.py
5
main.py
|
@ -26,6 +26,7 @@ import yaml
|
||||||
import execution
|
import execution
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import server
|
import server
|
||||||
|
from server import BinaryEventTypes
|
||||||
from nodes import init_custom_nodes
|
from nodes import init_custom_nodes
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,8 +41,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
def hook(value, total):
|
def hook(value, total, preview_image_bytes):
|
||||||
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
|
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
|
||||||
|
if preview_image_bytes is not None:
|
||||||
|
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
|
||||||
comfy.utils.set_progress_bar_global_hook(hook)
|
comfy.utils.set_progress_bar_global_hook(hook)
|
||||||
|
|
||||||
def cleanup_temp():
|
def cleanup_temp():
|
||||||
|
|
104
nodes.py
104
nodes.py
|
@ -7,6 +7,8 @@ import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import struct
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
@ -22,6 +24,8 @@ import comfy.samplers
|
||||||
import comfy.sample
|
import comfy.sample
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy.cli_args import args, LatentPreviewMethod
|
||||||
|
from comfy.taesd.taesd import TAESD
|
||||||
|
|
||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
|
@ -31,6 +35,32 @@ import importlib
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
|
class LatentPreviewer:
|
||||||
|
def decode_latent_to_preview(self, device, x0):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
|
def __init__(self):
|
||||||
|
self.latent_rgb_factors = torch.tensor([
|
||||||
|
# R G B
|
||||||
|
[0.298, 0.207, 0.208], # L1
|
||||||
|
[0.187, 0.286, 0.173], # L2
|
||||||
|
[-0.158, 0.189, 0.264], # L3
|
||||||
|
[-0.184, -0.271, -0.473], # L4
|
||||||
|
], device="cpu")
|
||||||
|
|
||||||
|
def decode_latent_to_preview(self, device, x0):
|
||||||
|
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
|
||||||
|
|
||||||
|
latents_ubyte = (((latent_image + 1) / 2)
|
||||||
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
.byte()).cpu()
|
||||||
|
|
||||||
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
|
|
||||||
|
|
||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
|
@ -38,6 +68,7 @@ def interrupt_processing(value=True):
|
||||||
comfy.model_management.interrupt_current_processing(value)
|
comfy.model_management.interrupt_current_processing(value)
|
||||||
|
|
||||||
MAX_RESOLUTION=8192
|
MAX_RESOLUTION=8192
|
||||||
|
MAX_PREVIEW_RESOLUTION = 512
|
||||||
|
|
||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -248,6 +279,21 @@ class VAEEncodeForInpaint:
|
||||||
|
|
||||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||||
|
|
||||||
|
class TAESDPreviewerImpl(LatentPreviewer):
|
||||||
|
def __init__(self, taesd):
|
||||||
|
self.taesd = taesd
|
||||||
|
|
||||||
|
def decode_latent_to_preview(self, device, x0):
|
||||||
|
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
|
||||||
|
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
|
||||||
|
x_sample = x_sample.sub(0.5).mul(2)
|
||||||
|
|
||||||
|
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
|
preview_image = Image.fromarray(x_sample)
|
||||||
|
return preview_image
|
||||||
|
|
||||||
class SaveLatent:
|
class SaveLatent:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -931,6 +977,26 @@ class SetLatentNoiseMask:
|
||||||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_latent_to_preview_image(previewer, device, preview_format, x0):
|
||||||
|
preview_image = previewer.decode_latent_to_preview(device, x0)
|
||||||
|
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
|
||||||
|
|
||||||
|
preview_type = 1
|
||||||
|
if preview_format == "JPEG":
|
||||||
|
preview_type = 1
|
||||||
|
elif preview_format == "PNG":
|
||||||
|
preview_type = 2
|
||||||
|
|
||||||
|
bytesIO = BytesIO()
|
||||||
|
header = struct.pack(">I", preview_type)
|
||||||
|
bytesIO.write(header)
|
||||||
|
preview_image.save(bytesIO, format=preview_format)
|
||||||
|
preview_bytes = bytesIO.getvalue()
|
||||||
|
|
||||||
|
return preview_bytes
|
||||||
|
|
||||||
|
|
||||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
latent_image = latent["samples"]
|
latent_image = latent["samples"]
|
||||||
|
@ -945,9 +1011,39 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
noise_mask = latent["noise_mask"]
|
noise_mask = latent["noise_mask"]
|
||||||
|
|
||||||
|
preview_format = "JPEG"
|
||||||
|
if preview_format not in ["JPEG", "PNG"]:
|
||||||
|
preview_format = "JPEG"
|
||||||
|
|
||||||
|
previewer = None
|
||||||
|
if not args.disable_previews:
|
||||||
|
# TODO previewer methods
|
||||||
|
taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth")
|
||||||
|
taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth")
|
||||||
|
|
||||||
|
method = args.default_preview_method
|
||||||
|
|
||||||
|
if method == LatentPreviewMethod.Auto:
|
||||||
|
method = LatentPreviewMethod.Latent2RGB
|
||||||
|
if taesd_encoder_path and taesd_encoder_path:
|
||||||
|
method = LatentPreviewMethod.TAESD
|
||||||
|
|
||||||
|
if method == LatentPreviewMethod.TAESD:
|
||||||
|
if taesd_encoder_path and taesd_encoder_path:
|
||||||
|
taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device)
|
||||||
|
previewer = TAESDPreviewerImpl(taesd)
|
||||||
|
else:
|
||||||
|
print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth")
|
||||||
|
|
||||||
|
if previewer is None:
|
||||||
|
previewer = Latent2RGBPreviewer()
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
def callback(step, x0, x, total_steps):
|
def callback(step, x0, x, total_steps):
|
||||||
pbar.update_absolute(step + 1, total_steps)
|
preview_bytes = None
|
||||||
|
if previewer:
|
||||||
|
preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
|
||||||
|
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||||
|
|
||||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||||
|
@ -970,7 +1066,8 @@ class KSampler:
|
||||||
"negative": ("CONDITIONING", ),
|
"negative": ("CONDITIONING", ),
|
||||||
"latent_image": ("LATENT", ),
|
"latent_image": ("LATENT", ),
|
||||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
}}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "sample"
|
FUNCTION = "sample"
|
||||||
|
@ -997,7 +1094,8 @@ class KSamplerAdvanced:
|
||||||
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
||||||
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
||||||
"return_with_leftover_noise": (["disable", "enable"], ),
|
"return_with_leftover_noise": (["disable", "enable"], ),
|
||||||
}}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "sample"
|
FUNCTION = "sample"
|
||||||
|
|
39
server.py
39
server.py
|
@ -7,6 +7,7 @@ import execution
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
|
import struct
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
@ -25,6 +26,11 @@ from comfy.cli_args import args
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryEventTypes:
|
||||||
|
PREVIEW_IMAGE = 1
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
|
@ -457,16 +463,37 @@ class PromptServer():
|
||||||
return prompt_info
|
return prompt_info
|
||||||
|
|
||||||
async def send(self, event, data, sid=None):
|
async def send(self, event, data, sid=None):
|
||||||
message = {"type": event, "data": data}
|
if isinstance(data, (bytes, bytearray)):
|
||||||
|
await self.send_bytes(event, data, sid)
|
||||||
if isinstance(message, str) == False:
|
else:
|
||||||
message = json.dumps(message)
|
await self.send_json(event, data, sid)
|
||||||
|
|
||||||
|
def encode_bytes(self, event, data):
|
||||||
|
if not isinstance(event, int):
|
||||||
|
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||||
|
|
||||||
|
packed = struct.pack(">I", event)
|
||||||
|
message = bytearray(packed)
|
||||||
|
message.extend(data)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def send_bytes(self, event, data, sid=None):
|
||||||
|
message = self.encode_bytes(event, data)
|
||||||
|
|
||||||
if sid is None:
|
if sid is None:
|
||||||
for ws in self.sockets.values():
|
for ws in self.sockets.values():
|
||||||
await ws.send_str(message)
|
await ws.send_bytes(message)
|
||||||
elif sid in self.sockets:
|
elif sid in self.sockets:
|
||||||
await self.sockets[sid].send_str(message)
|
await self.sockets[sid].send_bytes(message)
|
||||||
|
|
||||||
|
async def send_json(self, event, data, sid=None):
|
||||||
|
message = {"type": event, "data": data}
|
||||||
|
|
||||||
|
if sid is None:
|
||||||
|
for ws in self.sockets.values():
|
||||||
|
await ws.send_json(message)
|
||||||
|
elif sid in self.sockets:
|
||||||
|
await self.sockets[sid].send_json(message)
|
||||||
|
|
||||||
def send_sync(self, event, data, sid=None):
|
def send_sync(self, event, data, sid=None):
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
|
|
|
@ -21,6 +21,7 @@ const colorPalettes = {
|
||||||
"MODEL": "#B39DDB", // light lavender-purple
|
"MODEL": "#B39DDB", // light lavender-purple
|
||||||
"STYLE_MODEL": "#C2FFAE", // light green-yellow
|
"STYLE_MODEL": "#C2FFAE", // light green-yellow
|
||||||
"VAE": "#FF6E6E", // bright red
|
"VAE": "#FF6E6E", // bright red
|
||||||
|
"TAESD": "#DCC274", // cheesecake
|
||||||
},
|
},
|
||||||
"litegraph_base": {
|
"litegraph_base": {
|
||||||
"NODE_TITLE_COLOR": "#999",
|
"NODE_TITLE_COLOR": "#999",
|
||||||
|
|
|
@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
|
||||||
this.socket = new WebSocket(
|
this.socket = new WebSocket(
|
||||||
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
|
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
|
||||||
);
|
);
|
||||||
|
this.socket.binaryType = "arraybuffer";
|
||||||
|
|
||||||
this.socket.addEventListener("open", () => {
|
this.socket.addEventListener("open", () => {
|
||||||
opened = true;
|
opened = true;
|
||||||
|
@ -70,39 +71,65 @@ class ComfyApi extends EventTarget {
|
||||||
|
|
||||||
this.socket.addEventListener("message", (event) => {
|
this.socket.addEventListener("message", (event) => {
|
||||||
try {
|
try {
|
||||||
const msg = JSON.parse(event.data);
|
if (event.data instanceof ArrayBuffer) {
|
||||||
switch (msg.type) {
|
const view = new DataView(event.data);
|
||||||
case "status":
|
const eventType = view.getUint32(0);
|
||||||
if (msg.data.sid) {
|
const buffer = event.data.slice(4);
|
||||||
this.clientId = msg.data.sid;
|
switch (eventType) {
|
||||||
window.name = this.clientId;
|
case 1:
|
||||||
|
const view2 = new DataView(event.data);
|
||||||
|
const imageType = view2.getUint32(0)
|
||||||
|
let imageMime
|
||||||
|
switch (imageType) {
|
||||||
|
case 1:
|
||||||
|
default:
|
||||||
|
imageMime = "image/jpeg";
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
imageMime = "image/png"
|
||||||
}
|
}
|
||||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
|
||||||
break;
|
this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
|
||||||
case "progress":
|
|
||||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
|
||||||
break;
|
|
||||||
case "executing":
|
|
||||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
|
||||||
break;
|
|
||||||
case "executed":
|
|
||||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
|
||||||
break;
|
|
||||||
case "execution_start":
|
|
||||||
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
|
||||||
break;
|
|
||||||
case "execution_error":
|
|
||||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
if (this.#registered.has(msg.type)) {
|
throw new Error(`Unknown binary websocket message of type ${eventType}`);
|
||||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
}
|
||||||
} else {
|
}
|
||||||
throw new Error("Unknown message type");
|
else {
|
||||||
}
|
const msg = JSON.parse(event.data);
|
||||||
|
switch (msg.type) {
|
||||||
|
case "status":
|
||||||
|
if (msg.data.sid) {
|
||||||
|
this.clientId = msg.data.sid;
|
||||||
|
window.name = this.clientId;
|
||||||
|
}
|
||||||
|
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||||
|
break;
|
||||||
|
case "progress":
|
||||||
|
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||||
|
break;
|
||||||
|
case "executing":
|
||||||
|
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||||
|
break;
|
||||||
|
case "executed":
|
||||||
|
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||||
|
break;
|
||||||
|
case "execution_start":
|
||||||
|
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||||
|
break;
|
||||||
|
case "execution_error":
|
||||||
|
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
if (this.#registered.has(msg.type)) {
|
||||||
|
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||||
|
} else {
|
||||||
|
throw new Error(`Unknown message type ${msg.type}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.warn("Unhandled message:", event.data);
|
console.warn("Unhandled message:", event.data, error);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,12 @@ export class ComfyApp {
|
||||||
*/
|
*/
|
||||||
this.nodeOutputs = {};
|
this.nodeOutputs = {};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores the preview image data for each node
|
||||||
|
* @type {Record<string, Image>}
|
||||||
|
*/
|
||||||
|
this.nodePreviewImages = {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* If the shift key on the keyboard is pressed
|
* If the shift key on the keyboard is pressed
|
||||||
* @type {boolean}
|
* @type {boolean}
|
||||||
|
@ -367,29 +373,52 @@ export class ComfyApp {
|
||||||
|
|
||||||
node.prototype.onDrawBackground = function (ctx) {
|
node.prototype.onDrawBackground = function (ctx) {
|
||||||
if (!this.flags.collapsed) {
|
if (!this.flags.collapsed) {
|
||||||
|
let imgURLs = []
|
||||||
|
let imagesChanged = false
|
||||||
|
|
||||||
const output = app.nodeOutputs[this.id + ""];
|
const output = app.nodeOutputs[this.id + ""];
|
||||||
if (output && output.images) {
|
if (output && output.images) {
|
||||||
if (this.images !== output.images) {
|
if (this.images !== output.images) {
|
||||||
this.images = output.images;
|
this.images = output.images;
|
||||||
this.imgs = null;
|
imagesChanged = true;
|
||||||
this.imageIndex = null;
|
imgURLs = imgURLs.concat(output.images.map(params => {
|
||||||
|
return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const preview = app.nodePreviewImages[this.id + ""]
|
||||||
|
if (this.preview !== preview) {
|
||||||
|
this.preview = preview
|
||||||
|
imagesChanged = true;
|
||||||
|
if (preview != null) {
|
||||||
|
imgURLs.push(preview);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesChanged) {
|
||||||
|
this.imageIndex = null;
|
||||||
|
if (imgURLs.length > 0) {
|
||||||
Promise.all(
|
Promise.all(
|
||||||
output.images.map((src) => {
|
imgURLs.map((src) => {
|
||||||
return new Promise((r) => {
|
return new Promise((r) => {
|
||||||
const img = new Image();
|
const img = new Image();
|
||||||
img.onload = () => r(img);
|
img.onload = () => r(img);
|
||||||
img.onerror = () => r(null);
|
img.onerror = () => r(null);
|
||||||
img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
|
img.src = src
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
).then((imgs) => {
|
).then((imgs) => {
|
||||||
if (this.images === output.images) {
|
if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
|
||||||
this.imgs = imgs.filter(Boolean);
|
this.imgs = imgs.filter(Boolean);
|
||||||
this.setSizeForImage?.();
|
this.setSizeForImage?.();
|
||||||
app.graph.setDirtyCanvas(true);
|
app.graph.setDirtyCanvas(true);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
this.imgs = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.imgs && this.imgs.length) {
|
if (this.imgs && this.imgs.length) {
|
||||||
|
@ -901,17 +930,20 @@ export class ComfyApp {
|
||||||
this.progress = null;
|
this.progress = null;
|
||||||
this.runningNodeId = detail;
|
this.runningNodeId = detail;
|
||||||
this.graph.setDirtyCanvas(true, false);
|
this.graph.setDirtyCanvas(true, false);
|
||||||
|
delete this.nodePreviewImages[this.runningNodeId]
|
||||||
});
|
});
|
||||||
|
|
||||||
api.addEventListener("executed", ({ detail }) => {
|
api.addEventListener("executed", ({ detail }) => {
|
||||||
this.nodeOutputs[detail.node] = detail.output;
|
this.nodeOutputs[detail.node] = detail.output;
|
||||||
const node = this.graph.getNodeById(detail.node);
|
const node = this.graph.getNodeById(detail.node);
|
||||||
if (node?.onExecuted) {
|
if (node) {
|
||||||
node.onExecuted(detail.output);
|
if (node.onExecuted)
|
||||||
|
node.onExecuted(detail.output);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
api.addEventListener("execution_start", ({ detail }) => {
|
api.addEventListener("execution_start", ({ detail }) => {
|
||||||
|
this.runningNodeId = null;
|
||||||
this.lastExecutionError = null
|
this.lastExecutionError = null
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -922,6 +954,16 @@ export class ComfyApp {
|
||||||
this.canvas.draw(true, true);
|
this.canvas.draw(true, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
api.addEventListener("b_preview", ({ detail }) => {
|
||||||
|
const id = this.runningNodeId
|
||||||
|
if (id == null)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const blob = detail
|
||||||
|
const blobUrl = URL.createObjectURL(blob)
|
||||||
|
this.nodePreviewImages[id] = [blobUrl]
|
||||||
|
});
|
||||||
|
|
||||||
api.init();
|
api.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1465,8 +1507,10 @@ export class ComfyApp {
|
||||||
*/
|
*/
|
||||||
clean() {
|
clean() {
|
||||||
this.nodeOutputs = {};
|
this.nodeOutputs = {};
|
||||||
|
this.nodePreviewImages = {}
|
||||||
this.lastPromptError = null;
|
this.lastPromptError = null;
|
||||||
this.lastExecutionError = null;
|
this.lastExecutionError = null;
|
||||||
|
this.runningNodeId = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue