From ce35d8c659cb8340aa4c758de7cfc42cf311f7f3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 1 Jul 2023 15:07:39 -0400
Subject: [PATCH] Lower latency by batching some text encoder inputs.

---
 comfy/sd1_clip.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index ffcb849d..27b2f18e 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -10,21 +10,29 @@ import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty, _ = self.encode(self.empty_tokens)
-        output = []
-        first_pooled = None
+        to_encode = list(self.empty_tokens)
         for x in token_weight_pairs:
-            tokens = [list(map(lambda a: a[0], x))]
-            z, pooled = self.encode(tokens)
-            if first_pooled is None:
-                first_pooled = pooled
+            tokens = list(map(lambda a: a[0], x))
+            to_encode.append(tokens)
+
+        out, pooled = self.encode(to_encode)
+        z_empty = out[0:1]
+        if pooled.shape[0] > 1:
+            first_pooled = pooled[1:2]
+        else:
+            first_pooled = pooled[0:1]
+
+        output = []
+        for k in range(1, out.shape[0]):
+            z = out[k:k+1]
             for i in range(len(z)):
                 for j in range(len(z[i])):
-                    weight = x[j][1]
+                    weight = token_weight_pairs[k - 1][j][1]
                     z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-            output += [z]
+            output.append(z)
+
         if (len(output) == 0):
-            return self.encode(self.empty_tokens)
+            return z_empty, first_pooled
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
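
Note on the change: the old code called self.encode once for the empty prompt and once more per prompt chunk; the new code collects the empty prompt and every chunk into to_encode, encodes them in a single batched forward pass, then slices row 0 back out as z_empty and rows 1..N as the per-chunk embeddings. The per-token re-weighting relative to the empty embedding is unchanged; only the number of encoder calls drops from N + 1 to 1. Below is a minimal sketch of that batching pattern, with a stand-in encode function and made-up token values (not ComfyUI's actual API):

    import torch

    def encode(token_lists):
        # Stand-in for the CLIP text encoder: one forward pass over a whole batch,
        # returning a (batch, seq_len, dim) tensor of per-token embeddings.
        return torch.randn(len(token_lists), len(token_lists[0]), 768)

    empty_tokens = [[49406] + [49407] * 76]            # padded "empty prompt" (illustrative values)
    chunks = [[49406, 320, 1125] + [49407] * 74,       # per-chunk token lists, 77 tokens each
              [49406, 525, 3086] + [49407] * 74]

    # Before: N + 1 encoder calls (one per chunk, plus one for the empty prompt).
    z_empty_slow = encode(empty_tokens)
    chunk_outs_slow = [encode([c]) for c in chunks]

    # After: a single call over [empty_tokens..., chunk_0, chunk_1, ...].
    out = encode(list(empty_tokens) + chunks)
    z_empty = out[0:1]                                 # row 0 is the empty prompt
    chunk_outs = [out[k:k + 1] for k in range(1, out.shape[0])]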