makes cors a cli parameter

This commit is contained in:
EllangoK 2023-04-06 15:06:22 -04:00
parent 7d62d89f93
commit 48efae1608
2 changed files with 25 additions and 14 deletions

View File

@ -4,8 +4,10 @@ parser = argparse.ArgumentParser()
parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
@ -13,7 +15,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
vram_group = parser.add_mutually_exclusive_group() vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")

View File

@ -18,6 +18,7 @@ except ImportError:
sys.exit() sys.exit()
import mimetypes import mimetypes
from comfy.cli_args import args
@web.middleware @web.middleware
@ -27,18 +28,22 @@ async def cache_control(request: web.Request, handler):
response.headers.setdefault('Cache-Control', 'no-cache') response.headers.setdefault('Cache-Control', 'no-cache')
return response return response
@web.middleware def create_cors_middleware(allowed_origin: str):
async def cors_middleware(request: web.Request, handler): @web.middleware
if request.method == "OPTIONS": async def cors_middleware(request: web.Request, handler):
# Pre-flight request. Reply successfully: if request.method == "OPTIONS":
response = web.Response() # Pre-flight request. Reply successfully:
else: response = web.Response()
response = await handler(request) else:
response.headers['Access-Control-Allow-Origin'] = '*' response = await handler(request)
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
return response response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
return cors_middleware
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
@ -50,7 +55,12 @@ class PromptServer():
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.number = 0 self.number = 0
self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware])
middlewares = [cache_control]
if args.cors:
middlewares.append(create_cors_middleware(args.cors))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.sockets = dict() self.sockets = dict()
self.web_root = os.path.join(os.path.dirname( self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web") os.path.realpath(__file__)), "web")