Lower latency by batching some text encoder inputs.
parent 3b6fe51c1d
commit ce35d8c659
@@ -10,21 +10,29 @@ import contextlib
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty, _ = self.encode(self.empty_tokens)
-        output = []
-        first_pooled = None
+        to_encode = list(self.empty_tokens)
         for x in token_weight_pairs:
-            tokens = [list(map(lambda a: a[0], x))]
-            z, pooled = self.encode(tokens)
-            if first_pooled is None:
-                first_pooled = pooled
+            tokens = list(map(lambda a: a[0], x))
+            to_encode.append(tokens)
+
+        out, pooled = self.encode(to_encode)
+        z_empty = out[0:1]
+        if pooled.shape[0] > 1:
+            first_pooled = pooled[1:2]
+        else:
+            first_pooled = pooled[0:1]
+
+        output = []
+        for k in range(1, out.shape[0]):
+            z = out[k:k+1]
             for i in range(len(z)):
                 for j in range(len(z[i])):
-                    weight = x[j][1]
+                    weight = token_weight_pairs[k - 1][j][1]
                     z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-            output += [z]
+            output.append(z)
+
         if (len(output) == 0):
-            return self.encode(self.empty_tokens)
+            return z_empty, first_pooled
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
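What the change buys: previously the text encoder ran once for the empty prompt and once more per weighted chunk; now the empty prompt and every chunk are stacked into a single self.encode call, so out[0:1] is the empty-prompt embedding and out[k:k+1] lines up with token_weight_pairs[k - 1]. Below is a toy measurement of why one batched forward pass beats a loop of single-item passes, using a stock PyTorch encoder layer as a stand-in for the CLIP text model; the layer, shapes, and timing harness are illustrative, not from this repo.

import time
import torch

# Stand-in for the CLIP text encoder. Per-call overhead (Python dispatch,
# kernel launches) is paid once per forward pass, so fewer passes means
# lower latency.
encoder = torch.nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True).eval()
chunks = [torch.randn(1, 77, 768) for _ in range(4)]  # e.g. empty prompt + 3 chunks

with torch.no_grad():
    # Old behaviour: one forward pass per chunk.
    t0 = time.perf_counter()
    outs = [encoder(c) for c in chunks]
    t_loop = time.perf_counter() - t0

    # New behaviour: stack the chunks and run a single forward pass.
    # Row k of the batched output matches the k-th single-item pass.
    t0 = time.perf_counter()
    out = encoder(torch.cat(chunks, dim=0))
    t_batch = time.perf_counter() - t0

print(f"looped: {t_loop * 1e3:.1f} ms   batched: {t_batch * 1e3:.1f} ms")

On typical hardware the batched pass costs close to a single pass until the batch saturates the device, which is where the latency win comes from.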
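The re-weighting applied after the batched encode is a linear interpolation between each token's embedding and the empty-prompt embedding at the same position: z = (z - z_empty) * weight + z_empty. A scalar sketch of how the weight behaves, with numbers invented purely for illustration:

# One embedding component and its empty-prompt value (made-up numbers).
z, z_empty = 0.8, 0.2
for weight in (0.0, 0.5, 1.0, 1.3):
    print(weight, (z - z_empty) * weight + z_empty)
# 0.0 -> 0.2   token fully replaced by the empty embedding
# 0.5 -> 0.5   halfway between token and empty embedding
# 1.0 -> 0.8   token unchanged
# 1.3 -> 0.98  token pushed further away from the empty embedding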