Add `FrontendManager` to manage non-default front-end impl (#3897)
* Add frontend manager * Add tests * nit * Add unit test to github CI * Fix path * nit * ignore * Add logging * Install test deps * Remove 'stable' keyword support * Update test * Add web-root arg * Rename web-root to front-end-root * Add test on non-exist version number * Use repo owner/name to replace hard coded provider list * Inline cmd args * nit * Fix unit test
This commit is contained in:
parent
33346fd9b8
commit
99458e8aca
|
@ -24,3 +24,7 @@ jobs:
|
||||||
npm run test:generate
|
npm run test:generate
|
||||||
npm test -- --verbose
|
npm test -- --verbose
|
||||||
working-directory: ./tests-ui
|
working-directory: ./tests-ui
|
||||||
|
- name: Run Unit Tests
|
||||||
|
run: |
|
||||||
|
pip install -r tests-unit/requirements.txt
|
||||||
|
python -m pytest tests-unit
|
||||||
|
|
|
@ -18,3 +18,4 @@ venv/
|
||||||
/tests-ui/data/object_info.json
|
/tests-ui/data/object_info.json
|
||||||
/user/
|
/user/
|
||||||
*.log
|
*.log
|
||||||
|
web_custom_versions/
|
|
@ -0,0 +1,187 @@
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
|
|
||||||
|
|
||||||
|
REQUEST_TIMEOUT = 10 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
class Asset(TypedDict):
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class Release(TypedDict):
|
||||||
|
id: int
|
||||||
|
tag_name: str
|
||||||
|
name: str
|
||||||
|
prerelease: bool
|
||||||
|
created_at: str
|
||||||
|
published_at: str
|
||||||
|
body: str
|
||||||
|
assets: NotRequired[list[Asset]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrontEndProvider:
|
||||||
|
owner: str
|
||||||
|
repo: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def folder_name(self) -> str:
|
||||||
|
return f"{self.owner}_{self.repo}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def release_url(self) -> str:
|
||||||
|
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def all_releases(self) -> list[Release]:
|
||||||
|
releases = []
|
||||||
|
api_url = self.release_url
|
||||||
|
while api_url:
|
||||||
|
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
|
||||||
|
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||||
|
releases.extend(response.json())
|
||||||
|
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
|
||||||
|
if "next" in response.links:
|
||||||
|
api_url = response.links["next"]["url"]
|
||||||
|
else:
|
||||||
|
api_url = None
|
||||||
|
return releases
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def latest_release(self) -> Release:
|
||||||
|
latest_release_url = f"{self.release_url}/latest"
|
||||||
|
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
|
||||||
|
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def get_release(self, version: str) -> Release:
|
||||||
|
if version == "latest":
|
||||||
|
return self.latest_release
|
||||||
|
else:
|
||||||
|
for release in self.all_releases:
|
||||||
|
if release["tag_name"] in [version, f"v{version}"]:
|
||||||
|
return release
|
||||||
|
raise ValueError(f"Version {version} not found in releases")
|
||||||
|
|
||||||
|
|
||||||
|
def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||||
|
"""Download dist.zip from github release."""
|
||||||
|
asset_url = None
|
||||||
|
for asset in release.get("assets", []):
|
||||||
|
if asset["name"] == "dist.zip":
|
||||||
|
asset_url = asset["url"]
|
||||||
|
break
|
||||||
|
|
||||||
|
if not asset_url:
|
||||||
|
raise ValueError("dist.zip not found in the release assets")
|
||||||
|
|
||||||
|
# Use a temporary file to download the zip content
|
||||||
|
with tempfile.TemporaryFile() as tmp_file:
|
||||||
|
headers = {"Accept": "application/octet-stream"}
|
||||||
|
response = requests.get(
|
||||||
|
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
|
||||||
|
)
|
||||||
|
response.raise_for_status() # Ensure we got a successful response
|
||||||
|
|
||||||
|
# Write the content to the temporary file
|
||||||
|
tmp_file.write(response.content)
|
||||||
|
|
||||||
|
# Go back to the beginning of the temporary file
|
||||||
|
tmp_file.seek(0)
|
||||||
|
|
||||||
|
# Extract the zip file content to the destination path
|
||||||
|
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(destination_path)
|
||||||
|
|
||||||
|
|
||||||
|
class FrontendManager:
|
||||||
|
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||||
|
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
value (str): The version string to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple containing provider name and version.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
argparse.ArgumentTypeError: If the version string is invalid.
|
||||||
|
"""
|
||||||
|
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
|
||||||
|
match_result = re.match(VERSION_PATTERN, value)
|
||||||
|
if match_result is None:
|
||||||
|
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||||
|
|
||||||
|
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||||
|
"""
|
||||||
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_string (str): The version string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the initialized frontend.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If there is an error during the initialization process.
|
||||||
|
main error source might be request timeout or invalid URL.
|
||||||
|
"""
|
||||||
|
if version_string == DEFAULT_VERSION_STRING:
|
||||||
|
return cls.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
|
provider = FrontEndProvider(repo_owner, repo_name)
|
||||||
|
release = provider.get_release(version)
|
||||||
|
|
||||||
|
semantic_version = release["tag_name"].lstrip("v")
|
||||||
|
web_root = str(
|
||||||
|
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||||
|
)
|
||||||
|
if not os.path.exists(web_root):
|
||||||
|
os.makedirs(web_root, exist_ok=True)
|
||||||
|
logging.info(
|
||||||
|
"Downloading frontend(%s) version(%s) to (%s)",
|
||||||
|
provider.folder_name,
|
||||||
|
semantic_version,
|
||||||
|
web_root,
|
||||||
|
)
|
||||||
|
logging.debug(release)
|
||||||
|
download_release_asset_zip(release, destination_path=web_root)
|
||||||
|
return web_root
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_frontend(cls, version_string: str) -> str:
|
||||||
|
"""
|
||||||
|
Initializes the frontend with the specified version string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_string (str): The version string to initialize the frontend with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path of the initialized frontend.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cls.init_frontend_unsafe(version_string)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Failed to initialize frontend: %s", e)
|
||||||
|
logging.info("Falling back to the default frontend.")
|
||||||
|
return cls.DEFAULT_FRONTEND_PATH
|
|
@ -1,7 +1,10 @@
|
||||||
import argparse
|
import argparse
|
||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
import comfy.options
|
import comfy.options
|
||||||
|
|
||||||
|
|
||||||
class EnumAction(argparse.Action):
|
class EnumAction(argparse.Action):
|
||||||
"""
|
"""
|
||||||
Argparse action for handling Enums
|
Argparse action for handling Enums
|
||||||
|
@ -124,6 +127,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||||
|
|
||||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||||
|
|
||||||
|
# The default built-in provider hosted under web/
|
||||||
|
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--front-end-version",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_VERSION_STRING,
|
||||||
|
help="""
|
||||||
|
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
|
||||||
|
download available frontend implementations from GitHub releases.
|
||||||
|
|
||||||
|
The version string should be in the format of:
|
||||||
|
[repoOwner]/[repoName]@[version]
|
||||||
|
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||||
|
"""Validate if the given path is a directory."""
|
||||||
|
if path is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not os.path.isdir(path):
|
||||||
|
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||||
|
return path
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--front-end-root",
|
||||||
|
type=is_valid_directory,
|
||||||
|
default=None,
|
||||||
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
|
)
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
inference: mark as inference test (deselect with '-m "not inference"')
|
inference: mark as inference test (deselect with '-m "not inference"')
|
||||||
testpaths = tests
|
testpaths =
|
||||||
|
tests
|
||||||
|
tests-unit
|
||||||
addopts = -s
|
addopts = -s
|
||||||
|
pythonpath = .
|
||||||
|
|
11
server.py
11
server.py
|
@ -25,9 +25,10 @@ import mimetypes
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
|
@ -83,8 +84,12 @@ class PromptServer():
|
||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
self.web_root = os.path.join(os.path.dirname(
|
self.web_root = (
|
||||||
os.path.realpath(__file__)), "web")
|
FrontendManager.init_frontend(args.front_end_version)
|
||||||
|
if args.front_end_root is None
|
||||||
|
else args.front_end_root
|
||||||
|
)
|
||||||
|
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
# Pytest Unit Tests
|
||||||
|
|
||||||
|
## Install test dependencies
|
||||||
|
|
||||||
|
`pip install -r tests-units/requirements.txt`
|
||||||
|
|
||||||
|
## Run tests
|
||||||
|
`pytest tests-units/`
|
|
@ -0,0 +1,100 @@
|
||||||
|
import argparse
|
||||||
|
import pytest
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from app.frontend_management import (
|
||||||
|
FrontendManager,
|
||||||
|
FrontEndProvider,
|
||||||
|
Release,
|
||||||
|
)
|
||||||
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_releases():
|
||||||
|
return [
|
||||||
|
Release(
|
||||||
|
id=1,
|
||||||
|
tag_name="1.0.0",
|
||||||
|
name="Release 1.0.0",
|
||||||
|
prerelease=False,
|
||||||
|
created_at="2022-01-01T00:00:00Z",
|
||||||
|
published_at="2022-01-01T00:00:00Z",
|
||||||
|
body="Release notes for 1.0.0",
|
||||||
|
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||||
|
),
|
||||||
|
Release(
|
||||||
|
id=2,
|
||||||
|
tag_name="2.0.0",
|
||||||
|
name="Release 2.0.0",
|
||||||
|
prerelease=False,
|
||||||
|
created_at="2022-02-01T00:00:00Z",
|
||||||
|
published_at="2022-02-01T00:00:00Z",
|
||||||
|
body="Release notes for 2.0.0",
|
||||||
|
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_provider(mock_releases):
|
||||||
|
provider = FrontEndProvider(
|
||||||
|
owner="test-owner",
|
||||||
|
repo="test-repo",
|
||||||
|
)
|
||||||
|
provider.all_releases = mock_releases
|
||||||
|
provider.latest_release = mock_releases[1]
|
||||||
|
FrontendManager.PROVIDERS = [provider]
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_release(mock_provider, mock_releases):
|
||||||
|
version = "1.0.0"
|
||||||
|
release = mock_provider.get_release(version)
|
||||||
|
assert release == mock_releases[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_release_latest(mock_provider, mock_releases):
|
||||||
|
version = "latest"
|
||||||
|
release = mock_provider.get_release(version)
|
||||||
|
assert release == mock_releases[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_release_invalid_version(mock_provider):
|
||||||
|
version = "invalid"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mock_provider.get_release(version)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_frontend_default():
|
||||||
|
version_string = DEFAULT_VERSION_STRING
|
||||||
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
|
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_frontend_invalid_version():
|
||||||
|
version_string = "test-owner/test-repo@1.100.99"
|
||||||
|
with pytest.raises(HTTPError):
|
||||||
|
FrontendManager.init_frontend_unsafe(version_string)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_frontend_invalid_provider():
|
||||||
|
version_string = "invalid/invalid@latest"
|
||||||
|
with pytest.raises(HTTPError):
|
||||||
|
FrontendManager.init_frontend_unsafe(version_string)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_version_string():
|
||||||
|
version_string = "owner/repo@1.0.0"
|
||||||
|
repo_owner, repo_name, version = FrontendManager.parse_version_string(
|
||||||
|
version_string
|
||||||
|
)
|
||||||
|
assert repo_owner == "owner"
|
||||||
|
assert repo_name == "repo"
|
||||||
|
assert version == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_version_string_invalid():
|
||||||
|
version_string = "invalid"
|
||||||
|
with pytest.raises(argparse.ArgumentTypeError):
|
||||||
|
FrontendManager.parse_version_string(version_string)
|
|
@ -0,0 +1 @@
|
||||||
|
pytest>=7.8.0
|
Loading…
Reference in New Issue