Allow POST `/userdata/{file}` endpoint to return full file info (#5446)

* Refactor listuserdata

* Full info param

* Add tests

* Fix mock

* Add full_info support for move user file
This commit is contained in:
Chenlei Hu 2024-11-04 13:57:21 -05:00 committed by GitHub
parent 696672905f
commit c49025f01b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 215 additions and 30 deletions

View File

@ -4,15 +4,31 @@ import re
import uuid import uuid
import glob import glob
import shutil import shutil
import logging
from aiohttp import web from aiohttp import web
from urllib import parse from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
import folder_paths import folder_paths
from .app_settings import AppSettings from .app_settings import AppSettings
from typing import TypedDict
default_user = "default" default_user = "default"
class FileInfo(TypedDict):
path: str
size: int
modified: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
}
class UserManager(): class UserManager():
def __init__(self): def __init__(self):
user_directory = folder_paths.get_user_directory() user_directory = folder_paths.get_user_directory()
@ -154,6 +170,7 @@ class UserManager():
recurse = request.rel_url.query.get('recurse', '').lower() == "true" recurse = request.rel_url.query.get('recurse', '').lower() == "true"
full_info = request.rel_url.query.get('full_info', '').lower() == "true" full_info = request.rel_url.query.get('full_info', '').lower() == "true"
split_path = request.rel_url.query.get('split', '').lower() == "true"
# Use different patterns based on whether we're recursing or not # Use different patterns based on whether we're recursing or not
if recurse: if recurse:
@ -161,26 +178,21 @@ class UserManager():
else: else:
pattern = os.path.join(glob.escape(path), '*') pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse) def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info: if full_info:
results = [ return get_file_info(full_path, path)
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true" rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path and not full_info: if split_path:
results = [[x] + x.split('/') for x in results] return [rel_path] + rel_path.split('/')
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results) return web.json_response(results)
@ -208,20 +220,51 @@ class UserManager():
@routes.post("/userdata/{file}") @routes.post("/userdata/{file}")
async def post_userdata(request): async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request) path = get_user_data_path(request)
if not isinstance(path, str): if not isinstance(path, str):
return path return path
overwrite = request.query["overwrite"] != "false" overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(path): if not overwrite and os.path.exists(path):
return web.Response(status=409) return web.Response(status=409, text="File already exists")
body = await request.read() body = await request.read()
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(body) f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None)) user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
return web.json_response(resp) return web.json_response(resp)
@routes.delete("/userdata/{file}") @routes.delete("/userdata/{file}")
@ -236,6 +279,30 @@ class UserManager():
@routes.post("/userdata/{file}/move/{dest}") @routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request): async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True) source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str): if not isinstance(source, str):
return source return source
@ -244,12 +311,19 @@ class UserManager():
if not isinstance(source, str): if not isinstance(source, str):
return dest return dest
overwrite = request.query["overwrite"] != "false" overwrite = request.query.get("overwrite", 'true') != "false"
if not overwrite and os.path.exists(dest): full_info = request.query.get('full_info', 'false').lower() == "true"
return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'") if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest) shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None)) user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
return web.json_response(resp) return web.json_response(resp)

View File

@ -14,7 +14,7 @@ def user_manager(tmp_path):
um = UserManager() um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join( um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file tmp_path, file
) ) if file else tmp_path
return um return um
@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app) client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200 assert resp.status == 200
assert await resp.json() == [ assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app): async def test_listuserdata_invalid_directory(aiohttp_client, app):
@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
assert "/" in result[0]["path"] # Ensure forward slash is used assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt" assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt", data=content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was created with correct content
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == content
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
new_content = b"updated content"
resp = await client.post("/userdata/test.txt", data=new_content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was overwritten
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == new_content
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
assert resp.status == 409
# Verify original content unchanged
with open(tmp_path / "test.txt", "r") as f:
assert f.read() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
assert resp.status == 200
result = await resp.json()
assert result["path"] == "test.txt"
assert result["size"] == len(content)
assert "modified" in result
async def test_move_userdata(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt")
assert resp.status == 200
assert await resp.text() == '"dest.txt"'
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create source and destination files
with open(tmp_path / "source.txt", "w") as f:
f.write("source content")
with open(tmp_path / "dest.txt", "w") as f:
f.write("destination content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
assert resp.status == 409
# Verify files remain unchanged
with open(tmp_path / "source.txt", "r") as f:
assert f.read() == "source content"
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
assert resp.status == 200
result = await resp.json()
assert result["path"] == "dest.txt"
assert result["size"] == len("test content")
assert "modified" in result
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"