Change torch.xpu to ipex.optimize, xpu device initialization and remove workaround for text node issue from older IPEX. (#3388)
This commit is contained in:
parent
f81a6fade8
commit
a56d02efc7
|
@ -83,7 +83,7 @@ def get_torch_device():
|
|||
return torch.device("cpu")
|
||||
else:
|
||||
if is_intel_xpu():
|
||||
return torch.device("xpu")
|
||||
return torch.device("xpu", torch.xpu.current_device())
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
|
@ -304,7 +304,7 @@ class LoadedModel:
|
|||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
|
@ -552,8 +552,6 @@ def text_encoder_device():
|
|||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||
if is_intel_xpu():
|
||||
return torch.device("cpu")
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue