import sd1_clip
import open_clip
import torch
from torch.utils.checkpoint import checkpoint

class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate",
        "hidden"
    ]
    #version="laion2b_s32b_b79k"
    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77,
                 freeze=True, layer="penultimate", layer_idx=None):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
        del model.visual  # only the text transformer is needed; drop the vision tower
        self.model = model

        self.device = device
        self.max_length = max_length
        self.empty_tokens = [[49406] + [49407] + [0] * 75]  # start token, end token, then zero padding
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        elif self.layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < 24
            self.clip_layer(layer_idx)
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def clip_layer(self, layer_idx):
        #layer_idx should have the same logic as the one for SD1
        if abs(layer_idx) >= 24:
            self.layer_idx = 0
        else:
            if layer_idx < 0:
                self.layer_idx = -(layer_idx + 1)
            else:
                self.layer_idx = 24 - (layer_idx + 1)
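        # Examples of the mapping above (self.layer_idx counts how many final
        # transformer blocks are skipped in text_transformer_forward):
        #   clip_layer(-1) -> 0  (output of the last block)
        #   clip_layer(-2) -> 1  (penultimate block, the SD2 default)
        #   clip_layer(22) -> 1  (positive indices count from the bottom of the 24 blocks)
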
    def forward(self, tokens):
        tokens = torch.LongTensor(tokens).to(self.device)
        z = self.encode_with_transformer(tokens)
        return z

    def encode_with_transformer(self, tokens):
        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break  # stop early so the returned hidden state comes from the requested layer
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, tokens):
        return self(tokens)


class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, tokenizer_path=None):
        super().__init__(tokenizer_path, pad_with_end=False)
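
# Minimal usage sketch, assuming open_clip can build the ViT-H-14 text tower
# locally. The weights here are randomly initialized; a real pipeline loads
# them from an SD2 checkpoint. It only exercises the encode() path above.
if __name__ == "__main__":
    clip = SD2ClipModel(layer="penultimate")
    tokens = clip.empty_tokens  # one empty prompt: [start, end] + 75 padding ids
    z = clip.encode(tokens)
    print(z.shape)  # expected [1, 77, 1024] for the ViT-H-14 text encoder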