From 48efae16084b423166f9a1930b989489169d22cf Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 15:06:22 -0400 Subject: [PATCH] makes cors a cli parameter --- comfy/cli_args.py | 3 ++- server.py | 36 +++++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a27dc7a7..5133e0ae 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -4,8 +4,10 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") attn_group = parser.add_mutually_exclusive_group() @@ -13,7 +15,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") -parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") diff --git a/server.py b/server.py index 005bf9b2..a9c0b459 100644 --- a/server.py +++ b/server.py @@ -18,6 +18,7 @@ except ImportError: sys.exit() import mimetypes +from comfy.cli_args import args @web.middleware @@ -27,18 +28,22 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response -@web.middleware -async def cors_middleware(request: web.Request, handler): - if request.method == "OPTIONS": - # Pre-flight request. Reply successfully: - response = web.Response() - else: - response = await handler(request) - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' - response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' - response.headers['Access-Control-Allow-Credentials'] = 'true' - return response +def create_cors_middleware(allowed_origin: str): + @web.middleware + async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Access-Control-Allow-Origin'] = allowed_origin + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + return cors_middleware class PromptServer(): def __init__(self, loop): @@ -50,7 +55,12 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware]) + + middlewares = [cache_control] + if args.cors: + middlewares.append(create_cors_middleware(args.cors)) + + self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web")