From 56d802e1f3195bbc7f505345db5e5510358be025 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 5 Feb 2023 14:36:28 -0500 Subject: [PATCH] Use transformers CLIP instead of open_clip for SD2.x This should make things a bit cleaner. --- comfy/sd.py | 73 ++++++++++++++++++++------------- comfy/sd2_clip.py | 84 +++++--------------------------------- comfy/sd2_clip_config.json | 23 +++++++++++ 3 files changed, 77 insertions(+), 103 deletions(-) create mode 100644 comfy/sd2_clip_config.json diff --git a/comfy/sd.py b/comfy/sd.py index 777c9c2d..00c06004 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -40,6 +40,42 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() + keys_to_replace = { + "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", + "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", + "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", + } + + for x in keys_to_replace: + if x in sd: + sd[keys_to_replace[x]] = sd.pop(x) + + resblock_to_replace = { + "ln_1": "layer_norm1", + "ln_2": "layer_norm2", + "mlp.c_fc": "mlp.fc1", + "mlp.c_proj": "mlp.fc2", + "attn.out_proj": "self_attn.out_proj", + } + + for resblock in range(24): + for x in resblock_to_replace: + for y in ["weight", "bias"]: + k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y) + k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y) + if k in sd: + sd[k_to] = sd.pop(k) + + for y in ["weight", "bias"]: + k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y) + if k_from in sd: + weights = sd.pop(k_from) + for x in range(3): + p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] + k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y) + sd[k_to] = weights[1024*x:1024*(x + 1)] + for x in load_state_dict_to: x.load_state_dict(sd, strict=False) @@ -62,12 +98,6 @@ LORA_CLIP_MAP = { "self_attn.out_proj": "self_attn_out_proj", } -LORA_CLIP2_MAP = { - "mlp.c_fc": "mlp_fc1", - "mlp.c_proj": "mlp_fc2", - "attn.out_proj": "self_attn_out_proj", -} - LORA_UNET_MAP = { "proj_in": "proj_in", "proj_out": "proj_out", @@ -116,7 +146,7 @@ def model_lora_keys(model, key_map={}): k = "{}.{}.weight".format(tk, c) if k in sdk: lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c]) - key_map[lora_key] = (k, 0) + key_map[lora_key] = k up_counter += 1 if up_counter >= 4: counter += 1 @@ -124,7 +154,7 @@ def model_lora_keys(model, key_map={}): k = "model.diffusion_model.middle_block.1.{}.weight".format(c) if k in sdk: lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c]) - key_map[lora_key] = (k, 0) + key_map[lora_key] = k counter = 3 for b in range(12): tk = "model.diffusion_model.output_blocks.{}.1".format(b) @@ -133,29 +163,18 @@ def model_lora_keys(model, key_map={}): k = "{}.{}.weight".format(tk, c) if k in sdk: lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c]) - key_map[lora_key] = (k, 0) + key_map[lora_key] = k up_counter += 1 if up_counter >= 4: counter += 1 counter = 0 text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" - for b in range(12): + for b in range(24): for c in LORA_CLIP_MAP: k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = (k, 0) - for b in range(24): - for c in LORA_CLIP2_MAP: - k = "model.transformer.resblocks.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c]) - key_map[lora_key] = (k, 0) - k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b) - if k in sdk: - key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0) - key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1) - key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2) + key_map[lora_key] = k return key_map @@ -174,7 +193,7 @@ class ModelPatcher: p = {} model_sd = self.model.state_dict() for k in patches: - if k[0] in model_sd: + if k in model_sd: p[k] = patches[k] self.patches += [(strength, p)] return p.keys() @@ -184,8 +203,7 @@ class ModelPatcher: for p in self.patches: for k in p[1]: v = p[1][k] - key = k[0] - index = k[1] + key = k if key not in model_sd: print("could not patch. key doesn't exist in model:", k) continue @@ -199,10 +217,7 @@ class ModelPatcher: mat2 = v[1] if v[2] is not None: alpha *= v[2] / mat2.shape[0] - calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())) - if len(weight.shape) > 2: - calc = calc.reshape(weight.shape) - weight[index * mat1.shape[0]:(index + 1) * mat1.shape[0]] += calc.type(weight.dtype).to(weight.device) + weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): model_sd = self.model.state_dict() diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 52cecb32..351d920a 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -1,86 +1,22 @@ import sd1_clip -import open_clip import torch +import os -class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder): - """ - Uses the OpenCLIP transformer encoder for text - """ - LAYERS = [ - #"pooled", - "last", - "penultimate", - "hidden" - ] - #version="laion2b_s32b_b79k" - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, - freeze=True, layer="penultimate", layer_idx=None): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu')) - del model.visual - self.model = model - - self.device = device - self.max_length = max_length +class SD2ClipModel(sd1_clip.SD1ClipModel): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") + super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) self.empty_tokens = [[49406] + [49407] + [0] * 75] - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 + if layer == "last": + layer_idx = -1 + elif layer == "penultimate": + layer_idx = -2 elif self.layer == "hidden": assert layer_idx is not None assert abs(layer_idx) < 24 - self.clip_layer(layer_idx) else: raise NotImplementedError() - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def clip_layer(self, layer_idx): - #layer_idx should have the same logic as the one for SD1 - if abs(layer_idx) >= 24: - self.layer_idx = 0 - else: - if layer_idx < 0: - self.layer_idx = -(layer_idx + 1) - else: - self.layer_idx = 24 - (layer_idx + 1) - - def forward(self, tokens): - tokens = torch.LongTensor(tokens).to(self.device) - z = self.encode_with_transformer(tokens) - return z - - def encode_with_transformer(self, tokens): - x = self.model.token_embedding(tokens) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.model.ln_final(x) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - self.layer_idx: - break - if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - def encode(self, tokens): - return self(tokens) - - + self.clip_layer(layer_idx) class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None): diff --git a/comfy/sd2_clip_config.json b/comfy/sd2_clip_config.json new file mode 100644 index 00000000..ace6ef00 --- /dev/null +++ b/comfy/sd2_clip_config.json @@ -0,0 +1,23 @@ +{ + "architectures": [ + "CLIPTextModel" + ], + "attention_dropout": 0.0, + "bos_token_id": 0, + "dropout": 0.0, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_size": 1024, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 77, + "model_type": "clip_text_model", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "pad_token_id": 1, + "projection_dim": 512, + "torch_dtype": "float32", + "vocab_size": 49408 +}