Limit origin check to when host is loopback.

This should still prevent the exploit without breaking things for people
who use reverse proxies.
This commit is contained in:
comfyanonymous 2024-09-11 01:00:31 -04:00
parent 81778a7feb
commit 36c83cdbba
1 changed files with 33 additions and 1 deletions

View File

@ -12,6 +12,8 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import socket
import ipaddress
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -80,6 +82,32 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware(): def create_origin_only_middleware():
@web.middleware @web.middleware
async def origin_only_middleware(request: web.Request, handler): async def origin_only_middleware(request: web.Request, handler):
@ -93,12 +121,16 @@ def create_origin_only_middleware():
parsed = urllib.parse.urlparse(origin) parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower() origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain) host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None: if host_domain_parsed.port is None:
origin_domain = parsed.hostname origin_domain = parsed.hostname
if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0: if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain: if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403) return web.Response(status=403)