If host does not contain a port only compare the hostnames.
This commit is contained in:
parent
cd4955367e
commit
54fca4a218
13
server.py
13
server.py
|
@ -83,17 +83,22 @@ def create_cors_middleware(allowed_origin: str):
|
||||||
def create_origin_only_middleware():
|
def create_origin_only_middleware():
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def origin_only_middleware(request: web.Request, handler):
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
|
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||||
|
#in that case the Host and Origin hostnames won't match
|
||||||
|
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||||
if 'Host' in request.headers and 'Origin' in request.headers:
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
host = request.headers['Host']
|
host = request.headers['Host']
|
||||||
origin = request.headers['Origin']
|
origin = request.headers['Origin']
|
||||||
host_domain = host.lower()
|
host_domain = host.lower()
|
||||||
parsed = urllib.parse.urlparse(origin)
|
parsed = urllib.parse.urlparse(origin)
|
||||||
origin_domain = parsed.netloc.lower()
|
origin_domain = parsed.netloc.lower()
|
||||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
result = urllib.parse.urlsplit('//' + host_domain)
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
host_domain = result.hostname
|
host_domain = host_domain_parsed.hostname
|
||||||
|
if host_domain_parsed.port is None:
|
||||||
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
if len(host_domain) > 0 and len(origin_domain) > 0:
|
if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
if host_domain != origin_domain:
|
if host_domain != origin_domain:
|
||||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
|
|
Loading…
Reference in New Issue