From 36c83cdbba89a01830816240553f16dd72377cd9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Sep 2024 01:00:31 -0400 Subject: [PATCH] Limit origin check to when host is loopback. This should still prevent the exploit without breaking things for people who use reverse proxies. --- server.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index a1fca49c..a2c078e4 100644 --- a/server.py +++ b/server.py @@ -12,6 +12,8 @@ import json import glob import struct import ssl +import socket +import ipaddress from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO @@ -80,6 +82,32 @@ def create_cors_middleware(allowed_origin: str): return cors_middleware +def is_loopback(host): + if host is None: + return False + try: + if ipaddress.ip_address(host).is_loopback: + return True + else: + return False + except: + pass + + loopback = False + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM) + for family, _, _, _, sockaddr in r: + if not ipaddress.ip_address(sockaddr[0]).is_loopback: + return loopback + else: + loopback = True + except socket.gaierror: + pass + + return loopback + + def create_origin_only_middleware(): @web.middleware async def origin_only_middleware(request: web.Request, handler): @@ -93,12 +121,16 @@ def create_origin_only_middleware(): parsed = urllib.parse.urlparse(origin) origin_domain = parsed.netloc.lower() host_domain_parsed = urllib.parse.urlsplit('//' + host_domain) + + #limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit + loopback = is_loopback(host_domain_parsed.hostname) + if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host host_domain = host_domain_parsed.hostname if host_domain_parsed.port is None: origin_domain = parsed.hostname - if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0: + if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0: if host_domain != origin_domain: logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) return web.Response(status=403)