From 539ff487a81f4ed4f51ca9ece57756b573e52190 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 3 Apr 2023 15:49:28 -0400
Subject: [PATCH] Pull latest tomesd code from upstream.

---
 comfy/ldm/modules/tomesd.py | 69 ++++++++++++++++++++++++++-----------
 1 file changed, 48 insertions(+), 21 deletions(-)

diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py
index 1eafcd0a..6a13b80c 100644
--- a/comfy/ldm/modules/tomesd.py
+++ b/comfy/ldm/modules/tomesd.py
@@ -1,4 +1,4 @@
-
+#Taken from: https://github.com/dbolya/tomesd
 import torch
 from typing import Tuple, Callable
 
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
     return x
 
 
+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
+
+
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
-
     Args:
      - metric [B, N, C]: metric to use for similarity
      - w: image width in tokens
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 
     if r <= 0:
         return do_nothing, do_nothing
 
+    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
+
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx
 
         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
-        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
-
         if no_rand:
-            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
+            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
 
-        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
-        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
-        rand_idx = idx_buffer.argsort(dim=1)
+        # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
 
-        num_dst = int((1 / (sx*sy)) * N)
+        # Image is not divisible by sx or sy so we need to move it into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
         a_idx = rand_idx[:, num_dst:, :] # src
         b_idx = rand_idx[:, :num_dst, :] # dst
 
         def split(x):
             C = x.shape[-1]
-            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
-            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
+            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
             return src, dst
 
+        # Cosine similarity between A and B
         metric = metric / metric.norm(dim=-1, keepdim=True)
         a, b = split(metric)
         scores = a @ b.transpose(-1, -2)
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         # Can't reduce more than the # tokens in src
         r = min(a.shape[1], r)
 
+        # Find the most similar greedily
         node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
 
         unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
         src_idx = edge_idx[..., :r, :]  # Merged Tokens
-        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
+        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
 
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
 
-        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
-        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
+        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
 
         return torch.cat([unm, dst], dim=1)
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
         _, _, c = unm.shape
 
-        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
+        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
 
         # Combine back to the original shape
         out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
         out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
 
         return out
 
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 def get_functions(x, ratio, original_shape):
     b, c, original_h, original_w = original_shape
     original_tokens = original_h * original_w
-    downsample = int(math.sqrt(original_tokens // x.shape[1]))
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
 
     stride_x = 2
     stride_y = 2
     max_downsample = 1
     if downsample <= max_downsample:
-        w = original_w // downsample
-        h = original_h // downsample
+        w = int(math.ceil(original_w / downsample))
+        h = int(math.ceil(original_h / downsample))
         r = int(x.shape[1] * ratio)
         no_rand = False
         m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
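
Below is a minimal standalone sketch (not part of the patch) of how the updated bipartite_soft_matching_random2d is meant to be driven: you hand it a token tensor plus the token-grid geometry and it returns a merge/unmerge pair. The import path simply mirrors the file the patch touches (the ComfyUI repo must be on PYTHONPATH), and the batch size, channel width, grid size, and 0.5 merge ratio are arbitrary assumptions; h is deliberately chosen so that sy does not divide it, which exercises the new idx_buffer padding branch.

import torch
from comfy.ldm.modules.tomesd import bipartite_soft_matching_random2d

B, C = 2, 320      # arbitrary batch size and channel width (assumptions, not from the diff)
w, h = 24, 17      # token grid; h = 17 is not divisible by sy = 2, hitting the new padding path
N = w * h          # 408 tokens

x = torch.randn(B, N, C)   # token features; here also used as the similarity metric
r = int(N * 0.5)           # ask to merge away roughly half of the tokens

# Returns two callables: merge() shrinks [B, N, C] -> [B, N - r', C] with r' = min(r, #src tokens),
# unmerge() scatters a merged tensor back out to [B, N, C].
merge, unmerge = bipartite_soft_matching_random2d(x, w, h, sx=2, sy=2, r=r, no_rand=False)

y = merge(x)               # run the expensive block (e.g. attention) on the smaller tensor
x_restored = unmerge(y)    # expand back to the full token count

print(y.shape, x_restored.shape)   # torch.Size([2, 204, 320]) torch.Size([2, 408, 320])

In ComfyUI itself, get_functions() in the same file builds this pair automatically from the latent shape, using the ceil-based downsample, w, and h computations changed in the last hunk.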