From 742d5720d1b128c78266bfd7156fb578d664a95a Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 9 Jun 2024 16:41:04 -0400
Subject: [PATCH] Support zeroing out text embeddings with the attention mask.

---
 comfy/sd1_clip.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index e7ebf046..2729f14d 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -68,7 +68,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
+                 return_projected_pooled=True):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -90,6 +91,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks
+        self.zero_out_masked = zero_out_masked
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -179,9 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs[0].float()
         else:
-            z = outputs[1]
+            z = outputs[1].float()
+
+        if self.zero_out_masked and attention_mask is not None:
+            z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
         if len(outputs) >= 3:
@@ -190,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             elif outputs[2] is not None:
                 pooled_output = outputs[2].float()
 
-        return z.float(), pooled_output
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
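
A minimal standalone sketch (not part of the patch) of the behavior the new zero_out_masked flag enables, assuming a hypothetical embedding tensor z of shape [batch, tokens, dim] and a 0/1 attention_mask of shape [batch, tokens] like the ones passed through SDClipModel's forward pass:

    import torch

    # Hypothetical stand-ins for the text encoder output and the token mask.
    z = torch.randn(1, 77, 768)          # token embeddings from the text model
    attention_mask = torch.zeros(1, 77)  # 1 = real token, 0 = padding
    attention_mask[:, :10] = 1.0         # pretend the prompt used 10 tokens

    # Same operation the patch adds when zero_out_masked is True: broadcast the
    # mask over the embedding dimension so padded positions become exactly zero.
    z = z * attention_mask.unsqueeze(-1).float()

    assert torch.all(z[:, 10:] == 0)     # masked (padding) positions are zeroed

Without the flag, padded positions keep whatever values the encoder produced for the pad token; with it, they contribute nothing to downstream conditioning.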