Allow passing custom path to clip-g and clip-h.

This commit is contained in:
comfyanonymous 2023-07-03 15:45:04 -04:00
parent dc9d1f31c8
commit 5e6bc824aa
2 changed files with 4 additions and 4 deletions

View File

@ -3,9 +3,9 @@ import torch
import os import os
class SD2ClipModel(sd1_clip.SD1ClipModel): class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
self.empty_tokens = [[49406] + [49407] + [0] * 75] self.empty_tokens = [[49406] + [49407] + [0] * 75]
if layer == "last": if layer == "last":
pass pass

View File

@ -3,9 +3,9 @@ import torch
import os import os
class SDXLClipG(sd1_clip.SD1ClipModel): class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
self.empty_tokens = [[49406] + [49407] + [0] * 75] self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.layer_norm_hidden_state = False self.layer_norm_hidden_state = False