Merge branch 'master' into patch_hooks_improved_memory
This commit is contained in:
commit
4195dfb032
|
@ -40,6 +40,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
|
|
|
@ -2,6 +2,7 @@ from aiohttp import web
|
|||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
|
@ -11,7 +12,8 @@ class InternalRoutes:
|
|||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
|
@ -19,6 +21,8 @@ class InternalRoutes:
|
|||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
|
@ -34,7 +38,28 @@ class InternalRoutes:
|
|||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response(app.logger.get_logs())
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
||||
})
|
||||
|
||||
@self.routes.patch('/logs/subscribe')
|
||||
async def subscribe_logs(request):
|
||||
json_data = await request.json()
|
||||
client_id = json_data["clientId"]
|
||||
enabled = json_data["enabled"]
|
||||
if enabled:
|
||||
self.terminal_service.subscribe(client_id)
|
||||
else:
|
||||
self.terminal_service.unsubscribe(client_id)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
from app.logger import on_flush
|
||||
import os
|
||||
|
||||
|
||||
class TerminalService:
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
self.cols = None
|
||||
self.rows = None
|
||||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def update_size(self):
|
||||
sz = os.get_terminal_size()
|
||||
changed = False
|
||||
if sz.columns != self.cols:
|
||||
self.cols = sz.columns
|
||||
changed = True
|
||||
|
||||
if sz.lines != self.rows:
|
||||
self.rows = sz.lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
return {"cols": self.cols, "rows": self.rows}
|
||||
|
||||
return None
|
||||
|
||||
def subscribe(self, client_id):
|
||||
self.subscriptions.add(client_id)
|
||||
|
||||
def unsubscribe(self, client_id):
|
||||
self.subscriptions.discard(client_id)
|
||||
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
self.unsubscribe(client_id)
|
||||
continue
|
||||
|
||||
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
|
@ -1,20 +1,69 @@
|
|||
import logging
|
||||
from logging.handlers import MemoryHandler
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
logs = None
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
|
||||
def __init__(self, stream, *args, **kwargs):
|
||||
buffer = stream.buffer
|
||||
encoding = stream.encoding
|
||||
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
|
||||
self._lock = threading.Lock()
|
||||
self._flush_callbacks = []
|
||||
self._logs_since_flush = []
|
||||
|
||||
def write(self, data):
|
||||
entry = {"t": datetime.now().isoformat(), "m": data}
|
||||
with self._lock:
|
||||
self._logs_since_flush.append(entry)
|
||||
|
||||
# Simple handling for cr to overwrite the last output if it isnt a full line
|
||||
# else logs just get full of progress messages
|
||||
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
|
||||
logs.pop()
|
||||
logs.append(entry)
|
||||
super().write(data)
|
||||
|
||||
def flush(self):
|
||||
super().flush()
|
||||
for cb in self._flush_callbacks:
|
||||
cb(self._logs_since_flush)
|
||||
self._logs_since_flush = []
|
||||
|
||||
def on_flush(self, callback):
|
||||
self._flush_callbacks.append(callback)
|
||||
|
||||
|
||||
def get_logs():
|
||||
return "\n".join([formatter.format(x) for x in logs])
|
||||
return logs
|
||||
|
||||
|
||||
def on_flush(callback):
|
||||
if stdout_interceptor is not None:
|
||||
stdout_interceptor.on_flush(callback)
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Override output streams and log to buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
|
||||
global stdout_interceptor
|
||||
global stderr_interceptor
|
||||
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
|
||||
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(log_level)
|
||||
|
@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
|||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
# Create a memory handler with a deque as its buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||
memory_handler.buffer = logs
|
||||
memory_handler.setFormatter(formatter)
|
||||
logger.addHandler(memory_handler)
|
||||
|
|
|
@ -4,15 +4,31 @@ import re
|
|||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
from typing import TypedDict
|
||||
|
||||
default_user = "default"
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
}
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
@ -154,6 +170,7 @@ class UserManager():
|
|||
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
|
||||
# Use different patterns based on whether we're recursing or not
|
||||
if recurse:
|
||||
|
@ -161,26 +178,21 @@ class UserManager():
|
|||
else:
|
||||
pattern = os.path.join(glob.escape(path), '*')
|
||||
|
||||
results = glob.glob(pattern, recursive=recurse)
|
||||
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
|
||||
if full_info:
|
||||
return get_file_info(full_path, path)
|
||||
|
||||
if full_info:
|
||||
results = [
|
||||
{
|
||||
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||
'size': os.path.getsize(x),
|
||||
'modified': os.path.getmtime(x)
|
||||
} for x in results if os.path.isfile(x)
|
||||
]
|
||||
else:
|
||||
results = [
|
||||
os.path.relpath(x, path).replace(os.sep, '/')
|
||||
for x in results
|
||||
if os.path.isfile(x)
|
||||
]
|
||||
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
|
||||
if split_path:
|
||||
return [rel_path] + rel_path.split('/')
|
||||
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path and not full_info:
|
||||
results = [[x] + x.split('/') for x in results]
|
||||
return rel_path
|
||||
|
||||
results = [
|
||||
process_full_path(full_path)
|
||||
for full_path in glob.glob(pattern, recursive=recurse)
|
||||
if os.path.isfile(full_path)
|
||||
]
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
|
@ -208,20 +220,51 @@ class UserManager():
|
|||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
"""
|
||||
Upload or update a user data file.
|
||||
|
||||
This endpoint handles file uploads to a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Path Parameters:
|
||||
- file: The target file path (URL encoded if necessary).
|
||||
|
||||
Returns:
|
||||
- 400: If 'file' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 409: If overwrite=false and the file already exists.
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
|
||||
The request body should contain the raw file content to be written.
|
||||
"""
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(path, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(path, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
|
@ -236,6 +279,30 @@ class UserManager():
|
|||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
async def move_userdata(request):
|
||||
"""
|
||||
Move or rename a user data file.
|
||||
|
||||
This endpoint handles moving or renaming files within a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Path Parameters:
|
||||
- file: The source file path (URL encoded if necessary)
|
||||
- dest: The destination file path (URL encoded if necessary)
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Returns:
|
||||
- 400: If either 'file' or 'dest' parameter is missing
|
||||
- 403: If either requested path is not allowed
|
||||
- 404: If the source file does not exist
|
||||
- 409: If overwrite=false and the destination file already exists
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
"""
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
|
@ -244,12 +311,19 @@ class UserManager():
|
|||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
logging.info(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(dest, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(dest, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
|
|
@ -190,7 +190,21 @@ class Mochi(LatentFormat):
|
|||
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
self.latent_rgb_factors = None #TODO
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0069, -0.0045, 0.0018],
|
||||
[ 0.0154, -0.0692, -0.0274],
|
||||
[ 0.0333, 0.0019, 0.0206],
|
||||
[-0.1390, 0.0628, 0.1678],
|
||||
[-0.0725, 0.0134, -0.1898],
|
||||
[ 0.0074, -0.0270, -0.0209],
|
||||
[-0.0176, -0.0277, -0.0221],
|
||||
[ 0.5294, 0.5204, 0.3852],
|
||||
[-0.0326, -0.0446, -0.0143],
|
||||
[-0.0659, 0.0153, -0.0153],
|
||||
[ 0.0185, -0.0217, 0.0014],
|
||||
[-0.0396, -0.0495, -0.0281]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
|
|
|
@ -151,8 +151,8 @@ class Flux(nn.Module):
|
|||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
|
|
@ -2,12 +2,16 @@
|
|||
#adapted to ComfyUI
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
@ -158,8 +162,10 @@ class ResBlock(nn.Module):
|
|||
*,
|
||||
affine: bool = True,
|
||||
attn_block: Optional[nn.Module] = None,
|
||||
padding_mode: str = "replicate",
|
||||
causal: bool = True,
|
||||
prune_bottleneck: bool = False,
|
||||
padding_mode: str,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
@ -170,23 +176,23 @@ class ResBlock(nn.Module):
|
|||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
out_channels=channels // 2 if prune_bottleneck else channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
# causal=causal,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
PConv3d(
|
||||
in_channels=channels,
|
||||
in_channels=channels // 2 if prune_bottleneck else channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
# causal=causal,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -206,6 +212,81 @@ class ResBlock(nn.Module):
|
|||
return self.attn_block(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute temporal self-attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
chunk_size: Chunk size for large tensors.
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
if T == 1:
|
||||
# No attention for single frame.
|
||||
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||
qkv = self.qkv(x)
|
||||
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||
x = self.out(x)
|
||||
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||
|
||||
# 1D temporal attention.
|
||||
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||
qkv = self.qkv(x)
|
||||
|
||||
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||
# Output: x with shape [B, num_heads, t, head_dim]
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
|
||||
|
||||
if self.qk_norm:
|
||||
q = F.normalize(q, p=2, dim=-1)
|
||||
k = F.normalize(k, p=2, dim=-1)
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||
|
||||
assert x.size(0) == q.size(0)
|
||||
|
||||
x = self.out(x)
|
||||
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
**attn_kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = norm_fn(dim)
|
||||
self.attn = Attention(dim, **attn_kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.attn(self.norm(x))
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
|
||||
assert has_attention is False #NOTE: if this is ever true add back the attention code.
|
||||
|
||||
attn_block = None #AttentionBlock(channels) if has_attention else None
|
||||
|
||||
return ResBlock(
|
||||
channels, affine=True, attn_block=attn_block, **block_kwargs
|
||||
)
|
||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||
attn_block = AttentionBlock(channels) if has_attention else None
|
||||
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
|
||||
|
@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
|
|||
out_channels=out_channels,
|
||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
# First layer in each block always uses replicate padding
|
||||
padding_mode="replicate",
|
||||
bias=True,
|
||||
bias=block_kwargs["bias"],
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -382,7 +459,7 @@ class Decoder(nn.Module):
|
|||
blocks = []
|
||||
|
||||
first_block = [
|
||||
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
] # Input layer.
|
||||
# First set of blocks preserve channel count.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
|
@ -452,11 +529,165 @@ class Decoder(nn.Module):
|
|||
|
||||
return self.output_proj(x).contiguous()
|
||||
|
||||
class LatentDistribution:
|
||||
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
|
||||
"""Initialize latent distribution.
|
||||
|
||||
Args:
|
||||
mean: Mean of the distribution. Shape: [B, C, T, H, W].
|
||||
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
assert mean.shape == logvar.shape
|
||||
self.mean = mean
|
||||
self.logvar = logvar
|
||||
|
||||
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
|
||||
if temperature == 0.0:
|
||||
return self.mean
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
|
||||
else:
|
||||
assert noise.device == self.mean.device
|
||||
noise = noise.to(self.mean.dtype)
|
||||
|
||||
if temperature != 1.0:
|
||||
raise NotImplementedError(f"Temperature {temperature} is not supported.")
|
||||
|
||||
# Just Gaussian sample with no scaling of variance.
|
||||
return noise * torch.exp(self.logvar * 0.5) + self.mean
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
latent_dim: int,
|
||||
temporal_reductions: List[int],
|
||||
spatial_reductions: List[int],
|
||||
prune_bottlenecks: List[bool],
|
||||
has_attentions: List[bool],
|
||||
affine: bool = True,
|
||||
bias: bool = True,
|
||||
input_is_conv_1x1: bool = False,
|
||||
padding_mode: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_reductions = temporal_reductions
|
||||
self.spatial_reductions = spatial_reductions
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.fourier_features = FourierFeatures()
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
num_down_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == num_down_blocks + 2
|
||||
|
||||
layers = (
|
||||
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
|
||||
if not input_is_conv_1x1
|
||||
else [Conv1x1(in_channels, ch[0])]
|
||||
)
|
||||
|
||||
assert len(prune_bottlenecks) == num_down_blocks + 2
|
||||
assert len(has_attentions) == num_down_blocks + 2
|
||||
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
|
||||
|
||||
for _ in range(num_res_blocks[0]):
|
||||
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
|
||||
prune_bottlenecks = prune_bottlenecks[1:]
|
||||
has_attentions = has_attentions[1:]
|
||||
|
||||
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
|
||||
for i in range(num_down_blocks):
|
||||
layer = DownsampleBlock(
|
||||
ch[i],
|
||||
ch[i + 1],
|
||||
num_res_blocks=num_res_blocks[i + 1],
|
||||
temporal_reduction=temporal_reductions[i],
|
||||
spatial_reduction=spatial_reductions[i],
|
||||
prune_bottleneck=prune_bottlenecks[i],
|
||||
has_attention=has_attentions[i],
|
||||
affine=affine,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
layers.append(layer)
|
||||
|
||||
# Additional blocks.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
# Output layers.
|
||||
self.output_norm = norm_fn(ch[-1])
|
||||
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
|
||||
|
||||
@property
|
||||
def temporal_downsample(self):
|
||||
return math.prod(self.temporal_reductions)
|
||||
|
||||
@property
|
||||
def spatial_downsample(self):
|
||||
return math.prod(self.spatial_reductions)
|
||||
|
||||
def forward(self, x) -> LatentDistribution:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
|
||||
|
||||
Returns:
|
||||
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
|
||||
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
|
||||
logvar: Shape: [B, latent_dim, t, h, w].
|
||||
"""
|
||||
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
|
||||
x = self.fourier_features(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
x = self.output_norm(x)
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.output_proj(x)
|
||||
|
||||
means, logvar = torch.chunk(x, 2, dim=1)
|
||||
|
||||
assert means.ndim == 5
|
||||
assert logvar.shape == means.shape
|
||||
assert means.size(1) == self.latent_dim
|
||||
|
||||
return LatentDistribution(means, logvar)
|
||||
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = None #TODO once the model releases
|
||||
self.encoder = Encoder(
|
||||
in_channels=15,
|
||||
base_channels=64,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
temporal_reductions=[1, 2, 3],
|
||||
spatial_reductions=[2, 2, 2],
|
||||
prune_bottlenecks=[False, False, False, False, False],
|
||||
has_attentions=[False, True, True, True, True],
|
||||
affine=True,
|
||||
bias=True,
|
||||
input_is_conv_1x1=True,
|
||||
padding_mode="replicate"
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
|
@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
|
|||
)
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
return self.encoder(x).mode()
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
|
|
|
@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||
|
||||
return out
|
||||
|
||||
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||
SDP_BATCH_LIMIT = 2**15
|
||||
else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
if SDP_BATCH_LIMIT >= q.shape[0]:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -321,8 +321,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
|
||||
if scaled_fp8_weight is not None:
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
|
|
|
@ -2,6 +2,25 @@ import torch
|
|||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
|
@ -48,7 +67,7 @@ class CONST:
|
|||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
|
@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
timesteps = sampling_settings.get("timesteps", 1000)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
if zsnr is None:
|
||||
zsnr = sampling_settings.get("zsnr", False)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
|
@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
if zsnr:
|
||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
|
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -12,12 +13,7 @@ if TYPE_CHECKING:
|
|||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
|
|
46
comfy/sd.py
46
comfy/sd.py
|
@ -218,6 +218,7 @@ class VAE:
|
|||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
@ -287,16 +288,22 @@ class VAE:
|
|||
self.output_channels = 2
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
|
||||
self.latent_channels = 12
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
|
@ -401,24 +408,45 @@ class VAE:
|
|||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1,-1)
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
dims = samples.ndim - 2
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
|
|
|
@ -197,6 +197,8 @@ class SDXL(supported_models_base.BASE):
|
|||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||
self.sampling_settings["zsnr"] = True
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
|
|
|
@ -12,7 +12,7 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
|||
|
||||
class MochiT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
|
|
|
@ -848,3 +848,24 @@ class ProgressBar:
|
|||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
def reshape_mask(input_mask, output_shape):
|
||||
dims = len(output_shape) - 2
|
||||
|
||||
if dims == 1:
|
||||
scale_mode = "linear"
|
||||
|
||||
if dims == 2:
|
||||
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "bilinear"
|
||||
|
||||
if dims == 3:
|
||||
if len(input_mask.shape) < 5:
|
||||
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "trilinear"
|
||||
|
||||
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
|
||||
if mask.shape[1] < output_shape[1]:
|
||||
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
||||
return mask
|
||||
|
|
|
@ -3,9 +3,6 @@ import torch
|
|||
import comfy.model_management
|
||||
|
||||
class EmptyMochiLatentVideo:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
|
@ -15,10 +12,10 @@ class EmptyMochiLatentVideo:
|
|||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/mochi"
|
||||
CATEGORY = "latent/video"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
|
|
@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
|
|||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
|
|||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
|
||||
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
|
|
@ -57,12 +57,24 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
|||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
total = mask.shape[-1]
|
||||
x = round(math.sqrt((lh / lw) * total))
|
||||
xx = None
|
||||
for i in range(0, math.floor(math.sqrt(total) / 2)):
|
||||
for j in [(x + i), max(1, x - i)]:
|
||||
if total % j == 0:
|
||||
xx = j
|
||||
break
|
||||
if xx is not None:
|
||||
break
|
||||
|
||||
x = xx
|
||||
y = total // x
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
mask.reshape(b, x, y)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
|
|
|
@ -7,17 +7,19 @@ import re
|
|||
class TripleCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
|||
|
||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
|
@ -81,7 +81,8 @@ extension_mimetypes_cache = {
|
|||
}
|
||||
|
||||
def map_legacy(folder_name: str) -> str:
|
||||
legacy = {"unet": "diffusion_models"}
|
||||
legacy = {"unet": "diffusion_models",
|
||||
"clip": "text_encoders"}
|
||||
return legacy.get(folder_name, folder_name)
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
|
|
|
@ -47,7 +47,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
|
|||
if self.latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
if x0.ndim == 5:
|
||||
x0 = x0[0, :, 0]
|
||||
else:
|
||||
x0 = x0[0]
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
|
||||
return preview_to_image(latent_image)
|
||||
|
|
28
nodes.py
28
nodes.py
|
@ -293,15 +293,21 @@ class VAEDecodeTiled:
|
|||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
|
||||
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def decode(self, vae, samples, tile_size):
|
||||
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
|
||||
def decode(self, vae, samples, tile_size, overlap=64):
|
||||
if tile_size < overlap * 4:
|
||||
overlap = tile_size // 4
|
||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
|
||||
class VAEEncode:
|
||||
@classmethod
|
||||
|
@ -891,7 +897,7 @@ class UNETLoader:
|
|||
class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
|
@ -899,6 +905,8 @@ class CLIPLoader:
|
|||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
if type == "stable_cascade":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
|
@ -911,15 +919,15 @@ class CLIPLoader:
|
|||
else:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
|
||||
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
|
||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
class DualCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("clip"), ),
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
|
@ -927,9 +935,11 @@ class DualCLIPLoader:
|
|||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
if type == "sdxl":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == "sd3":
|
||||
|
|
|
@ -152,7 +152,7 @@ class PromptServer():
|
|||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.internal_routes = InternalRoutes()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
|
|
|
@ -14,7 +14,7 @@ def user_manager(tmp_path):
|
|||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
)
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
|
@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
|||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == [
|
||||
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||
]
|
||||
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
|
||||
|
||||
|
||||
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||
|
@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
|||
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||
assert result[0]["path"] == "subdir/file1.txt"
|
||||
|
||||
|
||||
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was created with correct content
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == content
|
||||
|
||||
|
||||
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
new_content = b"updated content"
|
||||
resp = await client.post("/userdata/test.txt", data=new_content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was overwritten
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == new_content
|
||||
|
||||
|
||||
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify original content unchanged
|
||||
with open(tmp_path / "test.txt", "r") as f:
|
||||
assert f.read() == "initial content"
|
||||
|
||||
|
||||
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "test.txt"
|
||||
assert result["size"] == len(content)
|
||||
assert "modified" in result
|
||||
|
||||
|
||||
async def test_move_userdata(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"dest.txt"'
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
||||
|
||||
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create source and destination files
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("source content")
|
||||
with open(tmp_path / "dest.txt", "w") as f:
|
||||
f.write("destination content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify files remain unchanged
|
||||
with open(tmp_path / "source.txt", "r") as f:
|
||||
assert f.read() == "source content"
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "destination content"
|
||||
|
||||
|
||||
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "dest.txt"
|
||||
assert result["size"] == len("test content")
|
||||
assert "modified" in result
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
|
|
@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory
|
|||
|
||||
@pytest.fixture
|
||||
def internal_routes():
|
||||
return InternalRoutes()
|
||||
return InternalRoutes(None)
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_factory(aiohttp_client, internal_routes):
|
||||
|
@ -102,7 +102,7 @@ async def test_file_service_initialization():
|
|||
# Create a mock instance
|
||||
mock_file_service_instance = MagicMock(spec=FileService)
|
||||
MockFileService.return_value = mock_file_service_instance
|
||||
internal_routes = InternalRoutes()
|
||||
internal_routes = InternalRoutes(None)
|
||||
|
||||
# Check if FileService was initialized with the correct parameters
|
||||
MockFileService.assert_called_once_with({
|
||||
|
@ -112,4 +112,4 @@ async def test_file_service_initialization():
|
|||
})
|
||||
|
||||
# Verify that the file_service attribute of InternalRoutes is set
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
||||
|
|
Loading…
Reference in New Issue