Add `FrontendManager` to manage non-default front-end impl (#3897)
* Add frontend manager * Add tests * nit * Add unit test to github CI * Fix path * nit * ignore * Add logging * Install test deps * Remove 'stable' keyword support * Update test * Add web-root arg * Rename web-root to front-end-root * Add test on non-exist version number * Use repo owner/name to replace hard coded provider list * Inline cmd args * nit * Fix unit test
This commit is contained in:
parent
33346fd9b8
commit
99458e8aca
|
@ -24,3 +24,7 @@ jobs:
|
|||
npm run test:generate
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
|
|
|
@ -17,4 +17,5 @@ venv/
|
|||
!/web/extensions/core/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
*.log
|
||||
*.log
|
||||
web_custom_versions/
|
|
@ -0,0 +1,187 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class Asset(TypedDict):
|
||||
url: str
|
||||
|
||||
|
||||
class Release(TypedDict):
|
||||
id: int
|
||||
tag_name: str
|
||||
name: str
|
||||
prerelease: bool
|
||||
created_at: str
|
||||
published_at: str
|
||||
body: str
|
||||
assets: NotRequired[list[Asset]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontEndProvider:
|
||||
owner: str
|
||||
repo: str
|
||||
|
||||
@property
|
||||
def folder_name(self) -> str:
|
||||
return f"{self.owner}_{self.repo}"
|
||||
|
||||
@property
|
||||
def release_url(self) -> str:
|
||||
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
|
||||
|
||||
@cached_property
|
||||
def all_releases(self) -> list[Release]:
|
||||
releases = []
|
||||
api_url = self.release_url
|
||||
while api_url:
|
||||
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
releases.extend(response.json())
|
||||
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
|
||||
if "next" in response.links:
|
||||
api_url = response.links["next"]["url"]
|
||||
else:
|
||||
api_url = None
|
||||
return releases
|
||||
|
||||
@cached_property
|
||||
def latest_release(self) -> Release:
|
||||
latest_release_url = f"{self.release_url}/latest"
|
||||
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
return response.json()
|
||||
|
||||
def get_release(self, version: str) -> Release:
|
||||
if version == "latest":
|
||||
return self.latest_release
|
||||
else:
|
||||
for release in self.all_releases:
|
||||
if release["tag_name"] in [version, f"v{version}"]:
|
||||
return release
|
||||
raise ValueError(f"Version {version} not found in releases")
|
||||
|
||||
|
||||
def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
"""Download dist.zip from github release."""
|
||||
asset_url = None
|
||||
for asset in release.get("assets", []):
|
||||
if asset["name"] == "dist.zip":
|
||||
asset_url = asset["url"]
|
||||
break
|
||||
|
||||
if not asset_url:
|
||||
raise ValueError("dist.zip not found in the release assets")
|
||||
|
||||
# Use a temporary file to download the zip content
|
||||
with tempfile.TemporaryFile() as tmp_file:
|
||||
headers = {"Accept": "application/octet-stream"}
|
||||
response = requests.get(
|
||||
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
|
||||
)
|
||||
response.raise_for_status() # Ensure we got a successful response
|
||||
|
||||
# Write the content to the temporary file
|
||||
tmp_file.write(response.content)
|
||||
|
||||
# Go back to the beginning of the temporary file
|
||||
tmp_file.seek(0)
|
||||
|
||||
# Extract the zip file content to the destination path
|
||||
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_path)
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
"""
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
|
||||
match_result = re.match(VERSION_PATTERN, value)
|
||||
if match_result is None:
|
||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
web_root = str(
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
|
@ -1,7 +1,10 @@
|
|||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
import comfy.options
|
||||
|
||||
|
||||
class EnumAction(argparse.Action):
|
||||
"""
|
||||
Argparse action for handling Enums
|
||||
|
@ -124,6 +127,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
|||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-version",
|
||||
type=str,
|
||||
default=DEFAULT_VERSION_STRING,
|
||||
help="""
|
||||
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
|
||||
download available frontend implementations from GitHub releases.
|
||||
|
||||
The version string should be in the format of:
|
||||
[repoOwner]/[repoName]@[version]
|
||||
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
|
||||
""",
|
||||
)
|
||||
|
||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate if the given path is a directory."""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
if not os.path.isdir(path):
|
||||
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||
return path
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-root",
|
||||
type=is_valid_directory,
|
||||
default=None,
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
[pytest]
|
||||
markers =
|
||||
inference: mark as inference test (deselect with '-m "not inference"')
|
||||
testpaths = tests
|
||||
addopts = -s
|
||||
testpaths =
|
||||
tests
|
||||
tests-unit
|
||||
addopts = -s
|
||||
pythonpath = .
|
||||
|
|
11
server.py
11
server.py
|
@ -25,9 +25,10 @@ import mimetypes
|
|||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
|
@ -83,8 +84,12 @@ class PromptServer():
|
|||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.web_root = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), "web")
|
||||
self.web_root = (
|
||||
FrontendManager.init_frontend(args.front_end_version)
|
||||
if args.front_end_root is None
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Pytest Unit Tests
|
||||
|
||||
## Install test dependencies
|
||||
|
||||
`pip install -r tests-units/requirements.txt`
|
||||
|
||||
## Run tests
|
||||
`pytest tests-units/`
|
|
@ -0,0 +1,100 @@
|
|||
import argparse
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from app.frontend_management import (
|
||||
FrontendManager,
|
||||
FrontEndProvider,
|
||||
Release,
|
||||
)
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_releases():
|
||||
return [
|
||||
Release(
|
||||
id=1,
|
||||
tag_name="1.0.0",
|
||||
name="Release 1.0.0",
|
||||
prerelease=False,
|
||||
created_at="2022-01-01T00:00:00Z",
|
||||
published_at="2022-01-01T00:00:00Z",
|
||||
body="Release notes for 1.0.0",
|
||||
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||
),
|
||||
Release(
|
||||
id=2,
|
||||
tag_name="2.0.0",
|
||||
name="Release 2.0.0",
|
||||
prerelease=False,
|
||||
created_at="2022-02-01T00:00:00Z",
|
||||
published_at="2022-02-01T00:00:00Z",
|
||||
body="Release notes for 2.0.0",
|
||||
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(mock_releases):
|
||||
provider = FrontEndProvider(
|
||||
owner="test-owner",
|
||||
repo="test-repo",
|
||||
)
|
||||
provider.all_releases = mock_releases
|
||||
provider.latest_release = mock_releases[1]
|
||||
FrontendManager.PROVIDERS = [provider]
|
||||
return provider
|
||||
|
||||
|
||||
def test_get_release(mock_provider, mock_releases):
|
||||
version = "1.0.0"
|
||||
release = mock_provider.get_release(version)
|
||||
assert release == mock_releases[0]
|
||||
|
||||
|
||||
def test_get_release_latest(mock_provider, mock_releases):
|
||||
version = "latest"
|
||||
release = mock_provider.get_release(version)
|
||||
assert release == mock_releases[1]
|
||||
|
||||
|
||||
def test_get_release_invalid_version(mock_provider):
|
||||
version = "invalid"
|
||||
with pytest.raises(ValueError):
|
||||
mock_provider.get_release(version)
|
||||
|
||||
|
||||
def test_init_frontend_default():
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
||||
|
||||
|
||||
def test_init_frontend_invalid_version():
|
||||
version_string = "test-owner/test-repo@1.100.99"
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
def test_init_frontend_invalid_provider():
|
||||
version_string = "invalid/invalid@latest"
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
def test_parse_version_string():
|
||||
version_string = "owner/repo@1.0.0"
|
||||
repo_owner, repo_name, version = FrontendManager.parse_version_string(
|
||||
version_string
|
||||
)
|
||||
assert repo_owner == "owner"
|
||||
assert repo_name == "repo"
|
||||
assert version == "1.0.0"
|
||||
|
||||
|
||||
def test_parse_version_string_invalid():
|
||||
version_string = "invalid"
|
||||
with pytest.raises(argparse.ArgumentTypeError):
|
||||
FrontendManager.parse_version_string(version_string)
|
|
@ -0,0 +1 @@
|
|||
pytest>=7.8.0
|
Loading…
Reference in New Issue