Use own clip vision model implementation.
This commit is contained in:
parent
97015b6b38
commit
174eba8e95
|
@ -57,12 +57,7 @@ class CLIPEncoder(torch.nn.Module):
|
||||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
||||||
|
|
||||||
def forward(self, x, mask=None, intermediate_output=None):
|
def forward(self, x, mask=None, intermediate_output=None):
|
||||||
optimized_attention = optimized_attention_for_device(x.device, mask=True)
|
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
|
||||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
|
||||||
if mask is not None:
|
|
||||||
mask += causal_mask
|
|
||||||
else:
|
|
||||||
mask = causal_mask
|
|
||||||
|
|
||||||
if intermediate_output is not None:
|
if intermediate_output is not None:
|
||||||
if intermediate_output < 0:
|
if intermediate_output < 0:
|
||||||
|
@ -105,6 +100,12 @@ class CLIPTextModel_(torch.nn.Module):
|
||||||
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||||
|
|
||||||
|
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||||
|
if mask is not None:
|
||||||
|
mask += causal_mask
|
||||||
|
else:
|
||||||
|
mask = causal_mask
|
||||||
|
|
||||||
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
||||||
x = self.final_layer_norm(x)
|
x = self.final_layer_norm(x)
|
||||||
if i is not None and final_layer_norm_intermediate:
|
if i is not None and final_layer_norm_intermediate:
|
||||||
|
@ -128,3 +129,60 @@ class CLIPTextModel(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.text_model(*args, **kwargs)
|
return self.text_model(*args, **kwargs)
|
||||||
|
|
||||||
|
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
self.patch_embedding = operations.Conv2d(
|
||||||
|
in_channels=num_channels,
|
||||||
|
out_channels=embed_dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
num_positions = num_patches + 1
|
||||||
|
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, pixel_values):
|
||||||
|
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||||
|
return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVision(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
num_layers = config_dict["num_hidden_layers"]
|
||||||
|
embed_dim = config_dict["hidden_size"]
|
||||||
|
heads = config_dict["num_attention_heads"]
|
||||||
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
|
|
||||||
|
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
|
||||||
|
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||||
|
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||||
|
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||||
|
x = self.embeddings(pixel_values)
|
||||||
|
x = self.pre_layrnorm(x)
|
||||||
|
#TODO: attention_mask?
|
||||||
|
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||||
|
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||||
|
return x, i, pooled_output
|
||||||
|
|
||||||
|
class CLIPVisionModelProjection(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||||
|
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
x = self.vision_model(*args, **kwargs)
|
||||||
|
out = self.visual_projection(x[2])
|
||||||
|
return (x[0], x[1], out)
|
||||||
|
|
|
@ -1,13 +1,20 @@
|
||||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
|
|
||||||
from .utils import load_torch_file, transformers_convert, common_upscale
|
from .utils import load_torch_file, transformers_convert, common_upscale
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.clip_model
|
||||||
|
|
||||||
|
class Output:
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
def __setitem__(self, key, item):
|
||||||
|
setattr(self, key, item)
|
||||||
|
|
||||||
def clip_preprocess(image, size=224):
|
def clip_preprocess(image, size=224):
|
||||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||||
|
@ -22,17 +29,16 @@ def clip_preprocess(image, size=224):
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
def __init__(self, json_config):
|
def __init__(self, json_config):
|
||||||
config = CLIPVisionConfig.from_json_file(json_config)
|
with open(json_config) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
self.load_device = comfy.model_management.text_encoder_device()
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||||
self.dtype = torch.float32
|
self.dtype = torch.float32
|
||||||
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
||||||
self.dtype = torch.float16
|
self.dtype = torch.float16
|
||||||
|
|
||||||
with comfy.ops.use_comfy_ops(offload_device, self.dtype):
|
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
|
||||||
with modeling_utils.no_init_weights():
|
|
||||||
self.model = CLIPVisionModelWithProjection(config)
|
|
||||||
self.model.to(self.dtype)
|
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
|
@ -48,17 +54,12 @@ class ClipVisionModel():
|
||||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||||
|
|
||||||
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
|
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
|
||||||
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
|
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||||
|
|
||||||
for k in outputs:
|
|
||||||
t = outputs[k]
|
|
||||||
if t is not None:
|
|
||||||
if k == 'hidden_states':
|
|
||||||
outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
|
|
||||||
outputs["hidden_states"] = None
|
|
||||||
else:
|
|
||||||
outputs[k] = t.to(comfy.model_management.intermediate_device())
|
|
||||||
|
|
||||||
|
outputs = Output()
|
||||||
|
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||||
|
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||||
|
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def convert_to_transformers(sd, prefix):
|
def convert_to_transformers(sd, prefix):
|
||||||
|
|
Loading…
Reference in New Issue