Merge remote-tracking branch 'origin/master'

This commit is contained in:
kallen 2024-11-05 23:00:49 +08:00
commit 03eae57945
5 changed files with 219 additions and 33 deletions

View File

@ -40,6 +40,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.

View File

@ -4,15 +4,31 @@ import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
from typing import TypedDict
default_user = "default"
class FileInfo(TypedDict):
path: str
size: int
modified: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
}
class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
@ -154,6 +170,7 @@ class UserManager():
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
split_path = request.rel_url.query.get('split', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
@ -161,26 +178,21 @@ class UserManager():
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info:
return get_file_info(full_path, path)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path:
return [rel_path] + rel_path.split('/')
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results)
@ -208,20 +220,51 @@ class UserManager():
@routes.post("/userdata/{file}")
async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
return web.Response(status=409, text="File already exists")
body = await request.read()
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
return web.json_response(resp)
@routes.delete("/userdata/{file}")
@ -236,6 +279,30 @@ class UserManager():
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
@ -244,12 +311,19 @@ class UserManager():
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
print(f"moving '{source}' -> '{dest}'")
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
return web.json_response(resp)

View File

@ -151,8 +151,8 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@ -245,7 +245,7 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:

View File

@ -14,7 +14,7 @@ def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
) if file else tmp_path
return um
@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt", data=content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was created with correct content
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == content
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
new_content = b"updated content"
resp = await client.post("/userdata/test.txt", data=new_content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was overwritten
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == new_content
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
assert resp.status == 409
# Verify original content unchanged
with open(tmp_path / "test.txt", "r") as f:
assert f.read() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
assert resp.status == 200
result = await resp.json()
assert result["path"] == "test.txt"
assert result["size"] == len(content)
assert "modified" in result
async def test_move_userdata(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt")
assert resp.status == 200
assert await resp.text() == '"dest.txt"'
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create source and destination files
with open(tmp_path / "source.txt", "w") as f:
f.write("source content")
with open(tmp_path / "dest.txt", "w") as f:
f.write("destination content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
assert resp.status == 409
# Verify files remain unchanged
with open(tmp_path / "source.txt", "r") as f:
assert f.read() == "source content"
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
assert resp.status == 200
result = await resp.json()
assert result["path"] == "dest.txt"
assert result["size"] == len("test content")
assert "modified" in result
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"