Fix some issues with inference slowing down.
This commit is contained in:
parent ae197f651b · commit 1de69fe4d5
@@ -296,7 +296,7 @@ class LoadedModel:
 
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
-            return 0
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()
 
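Note: the hunk above makes model_memory_required() report only the part of the model that still has to be moved, instead of 0, when the model is already the one loaded on the target device. A minimal sketch of what a helper like model_offloaded_memory() could compute, assuming loaded_size() reports the bytes already resident on the device (illustrative, not necessarily the actual implementation):

    # Sketch only: assumes model_memory() is the full model size and
    # self.model.loaded_size() is how much of it already sits on the device.
    def model_offloaded_memory(self):
        return self.model_memory() - self.model.loaded_size()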
@@ -308,15 +308,21 @@ class LoadedModel:
 
         load_weights = not self.weights_loaded
 
-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-            else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
-        except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
-            self.model_unload()
-            raise e
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+                else:
+                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+            except Exception as e:
+                self.model.unpatch_model(self.model.offload_device)
+                self.model_unload()
+                raise e
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
            self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
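Note: with this hunk, a model that already has weights on the device (loaded_size() > 0) is no longer re-patched from scratch; the loader instead asks for more of it to be moved onto the device, using the lowvram budget, or an effectively unlimited amount (1e32) when no budget is set. A rough sketch of what model_use_more_vram() might delegate to, assuming a partial-load style helper on the patcher (hypothetical name and signature):

    # Hypothetical sketch: hand the extra VRAM budget to the patcher so it can
    # move additional weights onto self.device without a full re-patch.
    def model_use_more_vram(self, extra_memory):
        return self.model.partially_load(self.device, extra_memory)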
@@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+    for loaded_model in models_already_loaded:
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for loaded_model in models_to_load:
         weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
         if weights_unloaded is not None:
             loaded_model.weights_loaded = not weights_unloaded
 
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
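Note: the memory accounting above now also counts models that are already loaded, which is safe because model_memory_required() returns only the still-offloaded portion (a fully resident model contributes 0), and the headroom multiplier passed to free_memory() drops from 1.3 to 1.1. A rough, self-contained illustration with made-up sizes, not the real data structures:

    # Illustration only; byte counts are invented.
    extra_mem = 0
    models = [
        {"device": "cuda:0", "total": 4_000_000_000, "resident": 4_000_000_000},  # already loaded -> needs 0 more
        {"device": "cuda:0", "total": 1_500_000_000, "resident": 0},              # still to load -> needs its full size
    ]

    total_memory_required = {}
    for m in models:
        needed = m["total"] - m["resident"]  # what model_memory_required() now reports
        total_memory_required[m["device"]] = total_memory_required.get(m["device"], 0) + needed

    for device, needed in total_memory_required.items():
        print(device, needed * 1.1 + extra_mem)  # headroom multiplier lowered from 1.3 to 1.1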
@@ -102,7 +102,7 @@ class ModelPatcher:
         self.size = size
         self.model = model
         if not hasattr(self.model, 'device'):
-            logging.info("Model doesn't have a device attribute.")
+            logging.debug("Model doesn't have a device attribute.")
             self.model.device = offload_device
         elif self.model.device is None:
             self.model.device = offload_device