Merge branch 'master' into patch_hooks_improved_memory

This commit is contained in:
Jedrzej Kosinski 2024-11-11 11:26:27 -06:00
commit 4195dfb032
27 changed files with 776 additions and 140 deletions

View File

@ -40,6 +40,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.

View File

@ -2,6 +2,7 @@ from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger
class InternalRoutes:
@ -11,7 +12,8 @@ class InternalRoutes:
Check README.md for more information.
'''
def __init__(self):
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
@ -19,6 +21,8 @@ class InternalRoutes:
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/files')
@ -34,7 +38,28 @@ class InternalRoutes:
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})
@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)
return web.Response(status=200)
@self.routes.get('/folder_paths')
async def get_folder_paths(request):

View File

@ -0,0 +1,47 @@
from app.logger import on_flush
import os
class TerminalService:
    """Forward captured log entries and terminal dimensions to subscribed clients.

    On construction the service registers `send_messages` as a flush callback
    with the logger. Whenever it is invoked with new entries, they are sent to
    every subscribed client id that is still present in `server.sockets`.
    """

    # Dimensions reported when the process has no controlling terminal
    # (stdout piped, redirected, or captured), where os.get_terminal_size()
    # raises OSError.
    FALLBACK_SIZE = (80, 24)

    def __init__(self, server):
        """Store the server handle and hook into the logger's flush events.

        Args:
            server: Object exposing `sockets` (supports `in` by client id) and
                `send_sync(event, data, client_id)`.
        """
        self.server = server
        self.cols = None
        self.rows = None
        self.subscriptions = set()
        on_flush(self.send_messages)

    def get_terminal_size(self):
        """Return `(columns, lines)` of the terminal, or `FALLBACK_SIZE` if unavailable."""
        try:
            size = os.get_terminal_size()
        except OSError:
            # Not attached to a TTY; report a conventional default instead of
            # letting the error propagate into request handlers / flush hooks.
            return self.FALLBACK_SIZE
        return (size.columns, size.lines)

    def update_size(self):
        """Refresh the cached terminal size.

        Returns:
            `{"cols": ..., "rows": ...}` when either dimension changed since the
            previous call, otherwise `None`.
        """
        cols, rows = self.get_terminal_size()
        if cols == self.cols and rows == self.rows:
            return None
        self.cols = cols
        self.rows = rows
        return {"cols": self.cols, "rows": self.rows}

    def subscribe(self, client_id):
        """Start delivering log messages to `client_id`."""
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        """Stop delivering log messages to `client_id` (no-op if not subscribed)."""
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        """Send `entries` to every live subscriber; called by the logger on flush.

        Args:
            entries: Sequence of log entries produced since the last flush.
        """
        if not entries or not self.subscriptions:
            return

        new_size = self.update_size()

        # Iterate over a copy: unsubscribing below mutates the set.
        for client_id in self.subscriptions.copy():
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected
                self.unsubscribe(client_id)
                continue

            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)

View File

@ -1,20 +1,69 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading
logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stdout_interceptor = None
stderr_interceptor = None
class LogInterceptor(io.TextIOWrapper):
    """Text stream wrapper that records everything written to it.

    Each write is stored as `{"t": <ISO timestamp>, "m": <text>}` both in the
    module-level `logs` buffer and in a per-stream list of entries written
    since the last flush. Registered callbacks receive that list on `flush`.
    """

    def __init__(self, stream, *args, **kwargs):
        """Wrap `stream`, reusing its binary buffer, encoding and line buffering."""
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record `data` and pass it through to the wrapped stream.

        Returns:
            The value returned by the underlying `TextIOWrapper.write`.
        """
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # Simple handling for cr to overwrite the last output if it isnt a full line
            # else logs just get full of progress messages.
            # `logs` may still be empty on the very first write, so check it
            # before peeking at the last entry.
            if isinstance(data, str) and data.startswith("\r") and logs and not logs[-1]["m"].endswith("\n"):
                logs.pop()
            logs.append(entry)
        return super().write(data)

    def flush(self):
        """Flush the wrapped stream, then hand pending entries to every callback."""
        super().flush()
        # Swap the pending list out under the lock so that every callback sees
        # the same batch, concurrent writes are not lost, and the list is
        # cleared even when no callbacks are registered.
        with self._lock:
            pending = self._logs_since_flush
            self._logs_since_flush = []
        # Callbacks run outside the lock: they may themselves write to this
        # stream, and the lock is not reentrant.
        for cb in self._flush_callbacks:
            cb(pending)

    def on_flush(self, callback):
        """Register `callback(entries)` to be invoked on every flush."""
        self._flush_callbacks.append(callback)
def get_logs():
return "\n".join([formatter.format(x) for x in logs])
return logs
def on_flush(callback):
    """Register `callback` on whichever stdout/stderr interceptors are installed."""
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is None:
            continue
        interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Override output streams and log to buffer
logs = deque(maxlen=capacity)
global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)
# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)

View File

@ -4,15 +4,31 @@ import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
from typing import TypedDict
default_user = "default"
class FileInfo(TypedDict):
    """Metadata describing a single file, as returned to API clients."""

    # Path relative to the listing root, always using '/' as the separator.
    path: str
    # Size in bytes, as reported by os.path.getsize.
    size: int
    # Modification time as reported by os.path.getmtime (a float, in seconds).
    modified: float


def get_file_info(path: str, relative_to: str) -> FileInfo:
    """Build a `FileInfo` record for `path`.

    Args:
        path: Path of an existing file.
        relative_to: Directory the returned `path` entry is made relative to.

    Returns:
        A dict with the '/'-separated relative path, the size in bytes and the
        modification time.

    Raises:
        OSError: If `path` does not exist or is inaccessible.
    """
    relative_path = os.path.relpath(path, relative_to).replace(os.sep, '/')
    return {
        "path": relative_path,
        "size": os.path.getsize(path),
        "modified": os.path.getmtime(path),
    }
class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
@ -154,6 +170,7 @@ class UserManager():
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
split_path = request.rel_url.query.get('split', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
@ -161,26 +178,21 @@ class UserManager():
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info:
return get_file_info(full_path, path)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path:
return [rel_path] + rel_path.split('/')
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results)
@ -208,20 +220,51 @@ class UserManager():
@routes.post("/userdata/{file}")
async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
return web.Response(status=409, text="File already exists")
body = await request.read()
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
return web.json_response(resp)
@routes.delete("/userdata/{file}")
@ -236,6 +279,30 @@ class UserManager():
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
@ -244,12 +311,19 @@ class UserManager():
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
print(f"moving '{source}' -> '{dest}'")
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
return web.json_response(resp)

View File

@ -190,7 +190,21 @@ class Mochi(LatentFormat):
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors = None #TODO
self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO
def process_in(self, latent):

View File

@ -151,8 +151,8 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@ -2,12 +2,16 @@
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -158,8 +162,10 @@ class ResBlock(nn.Module):
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
):
super().__init__()
self.channels = channels
@ -170,23 +176,23 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
bias=bias,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
bias=bias,
causal=causal,
),
)
@ -206,6 +212,81 @@ class ResBlock(nn.Module):
return self.attn_block(x)
class Attention(nn.Module):
    """Multi-head self-attention over the temporal axis of a [B, C, T, H, W] tensor.

    Every spatial location (h, w) is treated as an independent sequence of T
    tokens with C features each.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        """Create the fused QKV projection and the output projection.

        Args:
            dim: Channel count C of the input; the number of heads is `dim // head_dim`.
            head_dim: Feature size of each attention head.
            qkv_bias: Whether the fused QKV projection has a bias term.
            out_bias: Whether the output projection has a bias term.
            qk_norm: If True, L2-normalize queries and keys before attention.
        """
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.qk_norm = qk_norm

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Output tensor. Shape: [B, C, T, H, W].
        """
        B, _, T, H, W = x.shape

        if T == 1:
            # No attention for single frame.
            # With a single token the attention weights reduce to 1, so the
            # result is just the value projection passed through `out`.
            x = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
            qkv = self.qkv(x)
            _, _, x = qkv.chunk(3, dim=-1)  # Throw away queries and keys.
            x = self.out(x)
            return x.movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]

        # 1D temporal attention.
        # Fold the spatial positions into the batch so each one attends only
        # across time.
        x = rearrange(x, "B C t h w -> (B h w) t C")
        qkv = self.qkv(x)

        # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
        # Output: x with shape [B, num_heads, t, head_dim]
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)

        if self.qk_norm:
            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)

        # NOTE(review): skip_reshape=True presumably tells the helper that
        # q/k/v are already laid out as [batch, heads, tokens, head_dim] and
        # that it returns [batch, tokens, heads * head_dim] - confirm against
        # optimized_attention.
        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        assert x.size(0) == q.size(0)

        x = self.out(x)
        x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
        return x
class AttentionBlock(nn.Module):
    """Residual wrapper around `Attention`: normalizes the input, attends, adds the input back."""

    def __init__(
        self,
        dim: int,
        **attn_kwargs,
    ) -> None:
        """Build the norm and attention layers.

        Args:
            dim: Channel count passed to both `norm_fn` and `Attention`.
            **attn_kwargs: Extra keyword arguments forwarded to `Attention`.
        """
        super().__init__()
        self.norm = norm_fn(dim)
        self.attn = Attention(dim, **attn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x + attn(norm(x))`."""
        return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code.
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
class DownsampleBlock(nn.Module):
@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate",
bias=True,
bias=block_kwargs["bias"],
)
)
@ -382,7 +459,7 @@ class Decoder(nn.Module):
blocks = []
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
@ -452,11 +529,165 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
    """Diagonal Gaussian over latents, parameterized by mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
        """
        assert mean.shape == logvar.shape
        self.mean = mean
        self.logvar = logvar

    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
        """Draw a sample from the distribution.

        Args:
            temperature: 0.0 returns the mean unchanged, 1.0 draws a regular
                Gaussian sample; any other value is rejected.
            generator: Optional RNG used when `noise` is not supplied.
            noise: Optional pre-drawn standard-normal noise on the same device
                as the mean; it is cast to the mean's dtype.

        Raises:
            NotImplementedError: If `temperature` is neither 0.0 nor 1.0.
        """
        mu = self.mean
        if temperature == 0.0:
            return mu

        if noise is not None:
            assert noise.device == mu.device
            noise = noise.to(mu.dtype)
        else:
            noise = torch.randn(mu.shape, device=mu.device, dtype=mu.dtype, generator=generator)

        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        # Plain Gaussian sample: the standard deviation is exp(logvar / 2),
        # with no additional scaling of the variance.
        std = torch.exp(self.logvar * 0.5)
        return noise * std + mu

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
class Encoder(nn.Module):
    """Video encoder producing a `LatentDistribution` (mean and log-variance).

    The network is an input projection, a first stage of residual blocks,
    one `DownsampleBlock` per reduction step, a final stage of residual blocks,
    and a norm + SiLU + 1x1 projection to `2 * latent_dim` channels.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        """Assemble the encoder layers.

        With `D = len(channel_multipliers) - 1` down blocks, the per-stage
        lists `num_res_blocks`, `prune_bottlenecks` and `has_attentions` must
        each have `D + 2` entries (first stage, D down blocks, final stage),
        and `temporal_reductions` / `spatial_reductions` must have `D` entries.

        Args:
            in_channels: Channel count fed to the input projection.
            base_channels: Base width; stage widths are `mult * base_channels`.
            channel_multipliers: Width multiplier for each resolution level.
            num_res_blocks: Number of residual blocks in each stage.
            latent_dim: Number of latent channels (the output has twice as many).
            temporal_reductions: Temporal reduction factor of each down block.
            spatial_reductions: Spatial reduction factor of each down block.
            prune_bottlenecks: Per-stage `prune_bottleneck` flag for the blocks.
            has_attentions: Per-stage `has_attention` flag for the blocks.
            affine: Forwarded to the residual and downsample blocks.
            bias: Forwarded to the residual and downsample blocks.
            input_is_conv_1x1: Use `Conv1x1` rather than `ops.Conv3d` with a
                (1, 1, 1) kernel for the input projection.
            padding_mode: Forwarded to the residual and downsample blocks.
        """
        super().__init__()
        self.temporal_reductions = temporal_reductions
        self.spatial_reductions = spatial_reductions
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.latent_dim = latent_dim

        self.fourier_features = FourierFeatures()
        ch = [mult * base_channels for mult in channel_multipliers]
        num_down_blocks = len(ch) - 1
        assert len(num_res_blocks) == num_down_blocks + 2

        layers = (
            [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
            if not input_is_conv_1x1
            else [Conv1x1(in_channels, ch[0])]
        )

        assert len(prune_bottlenecks) == num_down_blocks + 2
        assert len(has_attentions) == num_down_blocks + 2
        block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)

        for _ in range(num_res_blocks[0]):
            layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
        # Drop the first-stage entry so that index i below lines up with down
        # block i, while [-1] still addresses the final stage.
        prune_bottlenecks = prune_bottlenecks[1:]
        has_attentions = has_attentions[1:]

        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
        for i in range(num_down_blocks):
            layer = DownsampleBlock(
                ch[i],
                ch[i + 1],
                num_res_blocks=num_res_blocks[i + 1],
                temporal_reduction=temporal_reductions[i],
                spatial_reduction=spatial_reductions[i],
                prune_bottleneck=prune_bottlenecks[i],
                has_attention=has_attentions[i],
                affine=affine,
                bias=bias,
                padding_mode=padding_mode,
            )

            layers.append(layer)

        # Additional blocks.
        for _ in range(num_res_blocks[-1]):
            layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))

        self.layers = nn.Sequential(*layers)

        # Output layers.
        # The projection emits 2 * latent_dim channels: forward() splits them
        # into means and log-variances.
        self.output_norm = norm_fn(ch[-1])
        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)

    @property
    def temporal_downsample(self):
        """Product of all temporal reduction factors."""
        return math.prod(self.temporal_reductions)

    @property
    def spatial_downsample(self):
        """Product of all spatial reduction factors."""
        return math.prod(self.spatial_reductions)

    def forward(self, x) -> LatentDistribution:
        """Forward pass.

        Args:
            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]

        Returns:
            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
            logvar: Shape: [B, latent_dim, t, h, w].
        """
        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
        x = self.fourier_features(x)

        x = self.layers(x)

        x = self.output_norm(x)
        x = F.silu(x, inplace=True)
        x = self.output_proj(x)

        # First half of the channels are the means, second half the log-variances.
        means, logvar = torch.chunk(x, 2, dim=1)

        assert means.ndim == 5
        assert logvar.shape == means.shape
        assert means.size(1) == self.latent_dim

        return LatentDistribution(means, logvar)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
self.encoder = None #TODO once the model releases
self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
self.decoder = Decoder(
out_channels=3,
base_channels=128,
@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
)
def encode(self, x):
return self.encoder(x)
return self.encoder(x).mode()
def decode(self, x):
return self.decoder(x)

View File

@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if SDP_BATCH_LIMIT >= q.shape[0]:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -321,8 +321,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
if scaled_fp8_weight is not None:
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn

View File

@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a sigma schedule so the last step has (near) zero terminal SNR.

    The sigmas are converted to sqrt(alpha_bar), shifted so the last value is
    zero, rescaled so the first value is unchanged, and converted back. The
    last alpha_bar is pinned to a tiny positive constant so the final sigma
    stays finite.

    Args:
        sigmas: 1-D tensor of sigmas; it is not modified.

    Returns:
        A new tensor of rescaled sigmas with the same shape.
    """
    sqrt_alpha_bar = (1 / ((sigmas * sigmas) + 1)).sqrt()

    # Remember the endpoints before shifting.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep is zero, then scale so the first timestep
    # returns to its old value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    alpha_bar = sqrt_alpha_bar**2  # Revert sqrt
    alpha_bar[-1] = 4.8973451890853435e-08
    return ((1 - alpha_bar) / alpha_bar) ** 0.5
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@ -48,7 +67,7 @@ class CONST:
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
def __init__(self, model_config=None, zsnr=None):
super().__init__()
if model_config is not None:
@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
linear_end = sampling_settings.get("linear_end", 0.012)
timesteps = sampling_settings.get("timesteps", 1000)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
if zsnr is None:
zsnr = sampling_settings.get("zsnr", False)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
if given_betas is not None:
betas = given_betas
else:
@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
if zsnr:
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):

View File

@ -3,6 +3,7 @@ import uuid
import torch
import comfy.model_management
import comfy.conds
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
@ -12,12 +13,7 @@ if TYPE_CHECKING:
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
def get_models_from_cond(cond, model_type):
models = []

View File

@ -218,6 +218,7 @@ class VAE:
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@ -287,16 +288,22 @@ class VAE:
self.output_channels = 2
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
self.latent_channels = 12
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -401,24 +408,45 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")

View File

@ -197,6 +197,8 @@ class SDXL(supported_models_base.BASE):
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
return model_base.ModelType.V_PREDICTION_EDM
elif "v_pred" in state_dict:
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
self.sampling_settings["zsnr"] = True
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS

View File

@ -12,7 +12,7 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):

View File

@ -848,3 +848,24 @@ class ProgressBar:
def update(self, value):
self.update_absolute(self.current + value)
def reshape_mask(input_mask, output_shape):
    """Resize ``input_mask`` so it matches a tensor of shape ``output_shape``.

    ``output_shape`` is read as ``(batch, channels, *spatial)``; the number of
    spatial dimensions (1, 2 or 3) selects linear, bilinear or trilinear
    interpolation.

    Args:
        input_mask: mask tensor. For 2 spatial dims it is reshaped to
            ``(-1, 1, H, W)``; for 3 spatial dims a mask with fewer than 5
            dimensions is reshaped to ``(1, 1, -1, H, W)``. In every other
            case it is passed to ``interpolate`` unchanged, so it must already
            be laid out as ``(batch, channels, *spatial)``.
        output_shape: shape of the tensor the mask will be applied to.

    Returns:
        The mask interpolated to ``output_shape[2:]``, with its channel
        dimension tiled up to ``output_shape[1]`` when it is smaller, and
        repeated to ``output_shape[0]`` entries along the batch dimension.

    Raises:
        ValueError: if ``output_shape`` does not have 1, 2 or 3 spatial dims.
    """
    dims = len(output_shape) - 2
    scale_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    if dims not in scale_modes:
        raise ValueError("reshape_mask expects 1, 2 or 3 spatial dimensions, got {}".format(dims))

    # Start from the mask as given so that the branches which do not reshape
    # (1 spatial dim, or an already 5-D mask) still have a tensor to resize.
    mask = input_mask
    if dims == 2:
        mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
    elif dims == 3 and len(input_mask.shape) < 5:
        mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))

    mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_modes[dims])
    if mask.shape[1] < output_shape[1]:
        # Tile along channels, then trim to exactly the requested channel count.
        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:, :output_shape[1]]
    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
    return mask

View File

@ -3,9 +3,6 @@ import torch
import comfy.model_management
class EmptyMochiLatentVideo:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
@ -15,10 +12,10 @@ class EmptyMochiLatentVideo:
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/mochi"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
    """Create an all-zero latent for the requested video size.

    The tensor has shape
    ``[batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8]``
    and is allocated once, on the device reported by
    ``comfy.model_management.intermediate_device()`` at call time.

    Returns:
        A one-element tuple holding ``{"samples": latent}``.
    """
    latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
    return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {

View File

@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
return log_sigma.exp().to(timestep.device)
def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a sigma schedule so that its last step has (almost) zero SNR.

    The sigmas are converted to sqrt(alpha_bar), shifted so the last entry
    becomes zero, and scaled so the first entry keeps its original value.
    The result is converted back to sigmas and the input is not modified.
    """
    sqrt_alpha_bar = (1 / ((sigmas * sigmas) + 1)).sqrt()

    # Endpoints of the original schedule.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift the last timestep to zero, then stretch so the first is unchanged.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    alpha_bar = rescaled ** 2  # undo the square root
    # Use a tiny positive value instead of exactly zero for the final step.
    alpha_bar[-1] = 4.8973451890853435e-08
    return ((1 - alpha_bar) / alpha_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
m.add_object_patch("model_sampling", model_sampling)
return (m, )

View File

@ -57,12 +57,24 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
total = mask.shape[-1]
x = round(math.sqrt((lh / lw) * total))
xx = None
for i in range(0, math.floor(math.sqrt(total) / 2)):
for j in [(x + i), max(1, x - i)]:
if total % j == 0:
xx = j
break
if xx is not None:
break
x = xx
y = total // x
# Reshape
mask = (
mask.reshape(b, *mid_shape)
mask.reshape(b, x, y)
.unsqueeze(1)
.type(attn.dtype)
)

View File

@ -7,17 +7,19 @@ import re
class TripleCLIPLoader:
    """Node that loads three text-encoder checkpoints into one CLIP object."""

    @classmethod
    def INPUT_TYPES(s):
        """Offer the files registered under "text_encoders" for each of the three slots."""
        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
                             }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

    DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"

    def load_clip(self, clip_name1, clip_name2, clip_name3):
        """Resolve the three file names and load them together.

        Each name is resolved once through
        ``folder_paths.get_full_path_or_raise("text_encoders", ...)``, and the
        three paths are handed to ``comfy.sd.load_clip`` in the order given.

        Returns:
            A one-element tuple holding the loaded clip object.
        """
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
        clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (clip,)

View File

@ -18,7 +18,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
@ -81,7 +81,8 @@ extension_mimetypes_cache = {
}
def map_legacy(folder_name: str) -> str:
    """Translate a legacy model-folder name into its current name.

    "unet" maps to "diffusion_models" and "clip" maps to "text_encoders";
    any other name is returned unchanged.
    """
    legacy = {
        "unet": "diffusion_models",
        "clip": "text_encoders",
    }
    return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):

View File

@ -47,7 +47,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
if x0.ndim == 5:
x0 = x0[0, :, 0]
else:
x0 = x0[0]
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image)

View File

@ -293,15 +293,21 @@ class VAEDecodeTiled:
@classmethod
def INPUT_TYPES(s):
    """Declare the node inputs: a latent, a VAE, a tile size and a tile overlap.

    ``tile_size`` and ``overlap`` are both INT inputs with a step of 32;
    ``tile_size`` ranges from 128 to 4096 (default 512) and ``overlap`` from
    0 to 4096 (default 64).
    """
    return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                         "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
                         "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
                        }}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size, overlap=64):
    """Decode ``samples["samples"]`` with the VAE's tiled decoder.

    Args:
        vae: object providing ``decode_tiled(samples, tile_x, tile_y, overlap)``.
        samples: dict whose "samples" entry is the latent tensor.
        tile_size: tile edge length; divided by 8 before being passed on as
            both ``tile_x`` and ``tile_y``.
        overlap: tile overlap; capped at a quarter of ``tile_size`` and then
            divided by 8 before being passed on.

    Returns:
        A one-element tuple holding the decoded images. A 5-D result is
        flattened to 4-D by merging its two leading dimensions.
    """
    # Keep the overlap at no more than a quarter of the tile.
    if tile_size < overlap * 4:
        overlap = tile_size // 4
    # NOTE(review): the // 8 presumably converts pixel sizes to latent sizes - confirm.
    images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
    if len(images.shape) == 5: #Combine batches
        images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
    return (images, )
class VAEEncode:
@classmethod
@ -891,7 +897,7 @@ class UNETLoader:
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
}}
RETURN_TYPES = ("CLIP",)
@ -899,6 +905,8 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
def load_clip(self, clip_name, type="stable_diffusion"):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
@ -911,15 +919,15 @@ class CLIPLoader:
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
"clip_name2": (folder_paths.get_filename_list("clip"), ),
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux"], ),
}}
RETURN_TYPES = ("CLIP",)
@ -927,9 +935,11 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":

View File

@ -152,7 +152,7 @@ class PromptServer():
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop

View File

@ -14,7 +14,7 @@ def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
) if file else tmp_path
return um
@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
    """POSTing to a path with no existing file creates it with the sent bytes."""
    payload = b"test content"
    http = await aiohttp_client(app)

    response = await http.post("/userdata/test.txt", data=payload)

    assert response.status == 200
    assert await response.text() == '"test.txt"'

    # The uploaded bytes must be on disk unchanged.
    assert (tmp_path / "test.txt").read_bytes() == payload
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
    """POSTing to an existing file replaces its content by default."""
    target = tmp_path / "test.txt"
    target.write_text("initial content")  # file that is about to be replaced

    replacement = b"updated content"
    http = await aiohttp_client(app)
    response = await http.post("/userdata/test.txt", data=replacement)

    assert response.status == 200
    assert await response.text() == '"test.txt"'

    # The old content must be gone.
    assert target.read_bytes() == replacement
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    """With overwrite=false, POSTing to an existing file answers 409 and keeps it."""
    target = tmp_path / "test.txt"
    target.write_text("initial content")

    http = await aiohttp_client(app)
    response = await http.post("/userdata/test.txt?overwrite=false", data=b"new content")

    assert response.status == 409

    # The rejected upload must not have touched the file.
    assert target.read_text() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
    """With full_info=true, the POST response is JSON with path, size and modified."""
    payload = b"test content"
    http = await aiohttp_client(app)

    response = await http.post("/userdata/test.txt?full_info=true", data=payload)

    assert response.status == 200
    info = await response.json()
    assert info["path"] == "test.txt"
    assert info["size"] == len(payload)
    assert "modified" in info
async def test_move_userdata(aiohttp_client, app, tmp_path):
    """The move endpoint renames a file and answers with the new name."""
    source = tmp_path / "source.txt"
    destination = tmp_path / "dest.txt"
    source.write_text("test content")

    http = await aiohttp_client(app)
    response = await http.post("/userdata/source.txt/move/dest.txt")

    assert response.status == 200
    assert await response.text() == '"dest.txt"'

    # The source is gone and its content now sits at the destination.
    assert not source.exists()
    assert destination.read_text() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    """With overwrite=false, moving onto an existing file answers 409 and changes nothing."""
    source = tmp_path / "source.txt"
    destination = tmp_path / "dest.txt"
    source.write_text("source content")
    destination.write_text("destination content")

    http = await aiohttp_client(app)
    response = await http.post("/userdata/source.txt/move/dest.txt?overwrite=false")

    assert response.status == 409

    # Both files must still hold what they held before the request.
    assert source.read_text() == "source content"
    assert destination.read_text() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
    """With full_info=true, the move response is JSON describing the destination."""
    source = tmp_path / "source.txt"
    destination = tmp_path / "dest.txt"
    source.write_text("test content")

    http = await aiohttp_client(app)
    response = await http.post("/userdata/source.txt/move/dest.txt?full_info=true")

    assert response.status == 200
    info = await response.json()
    assert info["path"] == "dest.txt"
    assert info["size"] == len("test content")
    assert "modified" in info

    # The file itself must have moved as well.
    assert not source.exists()
    assert destination.read_text() == "test content"

View File

@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory
@pytest.fixture
def internal_routes():
    """Provide an InternalRoutes instance built without a prompt server."""
    # The constructor takes the prompt server as its argument; these tests
    # pass None in its place.
    return InternalRoutes(None)
@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
@ -102,7 +102,7 @@ async def test_file_service_initialization():
# Create a mock instance
mock_file_service_instance = MagicMock(spec=FileService)
MockFileService.return_value = mock_file_service_instance
internal_routes = InternalRoutes()
internal_routes = InternalRoutes(None)
# Check if FileService was initialized with the correct parameters
MockFileService.assert_called_once_with({
@ -112,4 +112,4 @@ async def test_file_service_initialization():
})
# Verify that the file_service attribute of InternalRoutes is set
assert internal_routes.file_service == mock_file_service_instance
assert internal_routes.file_service == mock_file_service_instance