From 5e6bc824aaa673813da6177832df26a247e1d7f3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 3 Jul 2023 15:45:04 -0400
Subject: [PATCH] Allow passing custom path to clip-g and clip-h.

---
 comfy/sd2_clip.py  | 4 ++--
 comfy/sdxl_clip.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index 1b43fdc1..3308e525 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -3,9 +3,9 @@ import torch
 import os
 
 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         if layer == "last":
             pass
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index f251168d..c768b9f9 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -3,9 +3,9 @@ import torch
 import os
 
 class SDXLClipG(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
         self.layer_norm_hidden_state = False
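
Usage sketch (not part of the patch): the snippet below shows how the new
textmodel_path keyword might be exercised on the two patched constructors. It
assumes sd1_clip.SD1ClipModel uses textmodel_path to load pretrained text-model
weights from the given location, which the diff implies but does not show; the
filesystem paths are hypothetical placeholders.

    # Default behavior is unchanged: textmodel_path=None loads the config as before.
    from comfy.sd2_clip import SD2ClipModel
    from comfy.sdxl_clip import SDXLClipG

    # clip-h (SD2) with custom pretrained weights; path is hypothetical.
    clip_h = SD2ClipModel(device="cpu", textmodel_path="/models/clip/ViT-H-14")

    # clip-g (SDXL) with custom pretrained weights; path is hypothetical.
    clip_g = SDXLClipG(device="cpu", textmodel_path="/models/clip/ViT-bigG-14")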