diff --git a/server.py b/server.py index 93147f6c..a1fca49c 100644 --- a/server.py +++ b/server.py @@ -83,17 +83,22 @@ def create_cors_middleware(allowed_origin: str): def create_origin_only_middleware(): @web.middleware async def origin_only_middleware(request: web.Request, handler): + #this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason. + #in that case the Host and Origin hostnames won't match + #I know the proper fix would be to add a cookie but this should take care of the problem in the meantime if 'Host' in request.headers and 'Origin' in request.headers: host = request.headers['Host'] origin = request.headers['Origin'] host_domain = host.lower() parsed = urllib.parse.urlparse(origin) origin_domain = parsed.netloc.lower() - if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers - result = urllib.parse.urlsplit('//' + host_domain) - host_domain = result.hostname + host_domain_parsed = urllib.parse.urlsplit('//' + host_domain) + if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host + host_domain = host_domain_parsed.hostname + if host_domain_parsed.port is None: + origin_domain = parsed.hostname - if len(host_domain) > 0 and len(origin_domain) > 0: + if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0: if host_domain != origin_domain: logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) return web.Response(status=403)