makes cors a cli parameter
This commit is contained in:
parent
7d62d89f93
commit
48efae1608
|
@ -4,8 +4,10 @@ parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
|
parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
|
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||||
|
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
|
||||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
|
@ -13,7 +15,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
|
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||||
|
|
36
server.py
36
server.py
|
@ -18,6 +18,7 @@ except ImportError:
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
|
@ -27,18 +28,22 @@ async def cache_control(request: web.Request, handler):
|
||||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@web.middleware
|
def create_cors_middleware(allowed_origin: str):
|
||||||
async def cors_middleware(request: web.Request, handler):
|
@web.middleware
|
||||||
if request.method == "OPTIONS":
|
async def cors_middleware(request: web.Request, handler):
|
||||||
# Pre-flight request. Reply successfully:
|
if request.method == "OPTIONS":
|
||||||
response = web.Response()
|
# Pre-flight request. Reply successfully:
|
||||||
else:
|
response = web.Response()
|
||||||
response = await handler(request)
|
else:
|
||||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
response = await handler(request)
|
||||||
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
|
|
||||||
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
|
response.headers['Access-Control-Allow-Origin'] = allowed_origin
|
||||||
response.headers['Access-Control-Allow-Credentials'] = 'true'
|
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
|
||||||
return response
|
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
|
||||||
|
response.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||||
|
return response
|
||||||
|
|
||||||
|
return cors_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
|
@ -50,7 +55,12 @@ class PromptServer():
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
self.number = 0
|
self.number = 0
|
||||||
self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware])
|
|
||||||
|
middlewares = [cache_control]
|
||||||
|
if args.cors:
|
||||||
|
middlewares.append(create_cors_middleware(args.cors))
|
||||||
|
|
||||||
|
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
|
||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
self.web_root = os.path.join(os.path.dirname(
|
self.web_root = os.path.join(os.path.dirname(
|
||||||
os.path.realpath(__file__)), "web")
|
os.path.realpath(__file__)), "web")
|
||||||
|
|
Loading…
Reference in New Issue