Pass device to CLIP model.
This commit is contained in:
parent
5e6bc824aa
commit
c3e96e637d
|
@ -526,12 +526,13 @@ class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False):
|
def __init__(self, target=None, embedding_directory=None, no_init=False):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params
|
params = target.params.copy()
|
||||||
clip = target.clip
|
clip = target.clip
|
||||||
tokenizer = target.tokenizer
|
tokenizer = target.tokenizer
|
||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_management.text_encoder_device()
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_management.text_encoder_offload_device()
|
||||||
|
params['device'] = load_device
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
#TODO: make sure this doesn't have a quality loss before enabling.
|
#TODO: make sure this doesn't have a quality loss before enabling.
|
||||||
# if model_management.should_use_fp16(load_device):
|
# if model_management.should_use_fp16(load_device):
|
||||||
|
|
Loading…
Reference in New Issue