diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 1b43fdc1..3308e525 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -3,9 +3,9 @@ import torch import os class SD2ClipModel(sd1_clip.SD1ClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) + super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) self.empty_tokens = [[49406] + [49407] + [0] * 75] if layer == "last": pass diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index f251168d..c768b9f9 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,9 +3,9 @@ import torch import os class SDXLClipG(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) + super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) self.empty_tokens = [[49406] + [49407] + [0] * 75] self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) self.layer_norm_hidden_state = False