diff --git a/comfy/sd.py b/comfy/sd.py
index 0cd75833..125b15b7 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -493,6 +493,8 @@ class CLIP:
     def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
+        else:
+            self.cond_stage_model.reset_clip_layer()
 
         model_management.load_model_gpu(self.patcher)
         cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index c008e963..d504bf77 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -46,12 +46,14 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
+        self.num_layers = 12
         if textmodel_path is not None:
             self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
             config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            self.num_layers = config.num_hidden_layers
             with comfy.ops.use_comfy_ops():
                 with modeling_utils.no_init_weights():
                     self.transformer = CLIPTextModel(config)
@@ -66,8 +68,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= 12
+            assert abs(layer_idx) <= self.num_layers
             self.clip_layer(layer_idx)
+        self.layer_default = (self.layer, self.layer_idx)
 
     def freeze(self):
         self.transformer = self.transformer.eval()
@@ -76,12 +79,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
 
     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= 12:
+        if abs(layer_idx) >= self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
             self.layer_idx = layer_idx
 
+    def reset_clip_layer(self):
+        self.layer = self.layer_default[0]
+        self.layer_idx = self.layer_default[1]
+
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
         next_new_token = token_dict_size = current_embeds.weight.shape[0]
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index 3308e525..1ffe31b6 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -4,20 +4,13 @@ import os
 
 class SD2ClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+        if layer == "penultimate":
+            layer="hidden"
+            layer_idx=23
+
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        if layer == "last":
-            pass
-        elif layer == "penultimate":
-            layer_idx = -1
-            self.clip_layer(layer_idx)
-        elif self.layer == "hidden":
-            assert layer_idx is not None
-            assert abs(layer_idx) < 24
-            self.clip_layer(layer_idx)
-        else:
-            raise NotImplementedError()
 
     def clip_layer(self, layer_idx):
         if layer_idx < 0:
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index d0803b10..65d2bb20 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -4,33 +4,16 @@ import os
 
 class SDXLClipG(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+        if layer == "penultimate":
+            layer="hidden"
+            layer_idx=-2
+
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = False
-        if layer == "last":
-            pass
-        elif layer == "penultimate":
-            layer_idx = -1
-            self.clip_layer(layer_idx)
-        elif self.layer == "hidden":
-            assert layer_idx is not None
-            assert abs(layer_idx) < 32
-            self.clip_layer(layer_idx)
-        else:
-            raise NotImplementedError()
-
-    def clip_layer(self, layer_idx):
-        if layer_idx < 0:
-            layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage.
-        if abs(layer_idx) >= 32:
-            self.layer = "hidden"
-            self.layer_idx = -2
-        else:
-            self.layer = "hidden"
-            self.layer_idx = layer_idx
 
     def load_sd(self, sd):
         if "text_projection" in sd:
@@ -69,6 +52,10 @@ class SDXLClipModel(torch.nn.Module):
         self.clip_l.clip_layer(layer_idx)
         self.clip_g.clip_layer(layer_idx)
 
+    def reset_clip_layer(self):
+        self.clip_g.reset_clip_layer()
+        self.clip_l.reset_clip_layer()
+
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_g = token_weight_pairs["g"]
         token_weight_pairs_l = token_weight_pairs["l"]
@@ -90,6 +77,9 @@ class SDXLRefinerClipModel(torch.nn.Module):
     def clip_layer(self, layer_idx):
         self.clip_g.clip_layer(layer_idx)
 
+    def reset_clip_layer(self):
+        self.clip_g.reset_clip_layer()
+
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_g = token_weight_pairs["g"]
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)