Support zeroing out text embeddings with the attention mask.
commit 742d5720d1
parent 6cd8ffc465
@@ -68,7 +68,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
+                 return_projected_pooled=True): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -90,6 +91,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks
+        self.zero_out_masked = zero_out_masked
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -179,9 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs[0].float()
         else:
-            z = outputs[1]
+            z = outputs[1].float()
+
+        if self.zero_out_masked and attention_mask is not None:
+            z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
         if len(outputs) >= 3:
@@ -190,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             elif outputs[2] is not None:
                 pooled_output = outputs[2].float()
 
-        return z.float(), pooled_output
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
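
The new zero_out_masked flag multiplies the encoder output by the attention mask, so embeddings at padded positions become zero instead of carrying whatever values the transformer produced for pad tokens. Below is a minimal standalone sketch of that effect, not the ComfyUI code path: apply_zero_out_masked is a hypothetical helper introduced here for illustration, and the tensor shapes are the usual CLIP text-encoder shapes assumed for the example.

# Illustrative sketch only: apply_zero_out_masked is a hypothetical helper,
# not part of ComfyUI.
import torch

def apply_zero_out_masked(z: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # z:              [batch, seq_len, hidden] token embeddings from the encoder
    # attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding
    # Broadcasting the mask over the hidden dimension zeroes every embedding
    # at a masked (padding) position and leaves real tokens untouched.
    return z * attention_mask.unsqueeze(-1).float()

z = torch.randn(1, 4, 8)
mask = torch.tensor([[1, 1, 0, 0]])        # last two positions are padding
out = apply_zero_out_masked(z, mask)
assert torch.all(out[:, 2:] == 0)          # padded embeddings are now zero
assert torch.equal(out[:, :2], z[:, :2])   # unmasked embeddings unchanged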