From 03eadbb53c82954ae5e42efa44903ed1319ff3d6 Mon Sep 17 00:00:00 2001 From: asagi4 <130366179+asagi4@users.noreply.github.com> Date: Wed, 6 Dec 2023 21:12:49 +0200 Subject: [PATCH] Make HyperTile deterministic --- comfy_extras/nodes_hypertile.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 0d7d4c95..15736b83 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -2,9 +2,10 @@ import math from einops import rearrange -import random +# Use torch rng for consistency across generations +from torch import randint -def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) # All big divisors of value (inclusive) @@ -12,8 +13,7 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter ns = [value // i for i in divisors[:max_options]] # has at least 1 element - random.seed(counter) - idx = random.randint(0, len(ns) - 1) + idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() return ns[idx] @@ -42,7 +42,6 @@ class HyperTile: latent_tile_size = max(32, tile_size) // 8 self.temp = None - self.counter = 1 def hypertile_in(q, k, v, extra_options): if q.shape[-1] in apply_to: @@ -53,10 +52,8 @@ class HyperTile: h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 - nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 - nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)