proper fix for sag.
This commit is contained in:
parent
8b90e50979
commit
9c1ed58ef2
|
@ -58,16 +58,23 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||||
# Global Average Pool
|
# Global Average Pool
|
||||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||||
|
|
||||||
f = float(lh) / float(lw)
|
total = mask.shape[-1]
|
||||||
fh = f ** 0.5
|
x = round(math.sqrt((lh / lw) * total))
|
||||||
fw = (1/f) ** 0.5
|
xx = None
|
||||||
S = mask.size(1) ** 0.5
|
for i in range(0, math.floor(math.sqrt(total) / 2)):
|
||||||
w = int(0.5 + S * fw)
|
for j in [(x + i), max(1, x - i)]:
|
||||||
h = int(0.5 + S * fh)
|
if total % j == 0:
|
||||||
|
xx = j
|
||||||
|
break
|
||||||
|
if xx is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
x = xx
|
||||||
|
y = total // x
|
||||||
|
|
||||||
# Reshape
|
# Reshape
|
||||||
mask = (
|
mask = (
|
||||||
mask.reshape(b, h, w)
|
mask.reshape(b, x, y)
|
||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
.type(attn.dtype)
|
.type(attn.dtype)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue