diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 3f03533e..1bd8d736 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -58,16 +58,23 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - f = float(lh) / float(lw) - fh = f ** 0.5 - fw = (1/f) ** 0.5 - S = mask.size(1) ** 0.5 - w = int(0.5 + S * fw) - h = int(0.5 + S * fh) + total = mask.shape[-1] + x = round(math.sqrt((lh / lw) * total)) + xx = None + for i in range(0, math.floor(math.sqrt(total) / 2)): + for j in [(x + i), max(1, x - i)]: + if total % j == 0: + xx = j + break + if xx is not None: + break + + x = xx + y = total // x # Reshape mask = ( - mask.reshape(b, h, w) + mask.reshape(b, x, y) .unsqueeze(1) .type(attn.dtype) )