Fix some issues with inference slowing down.

comfyanonymous 2024-08-10 15:29:36 -04:00
parent ae197f651b
commit 1de69fe4d5
2 changed files with 25 additions and 16 deletions

comfy/model_management.py

@@ -296,7 +296,7 @@ class LoadedModel:
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
-            return 0
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()
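
Why this matters for the slowdown fix: with lowvram partial loading a model can sit on its device with only part of its weights resident, and reporting 0 here made load_models_gpu believe no room was needed to finish loading it, so no memory was freed for the remainder. Below is a minimal sketch of the accounting, assuming model_offloaded_memory() is simply total size minus currently resident size; that helper's definition is not shown in this diff, and LoadedModelSketch is a made-up stand-in.

class LoadedModelSketch:
    """Toy stand-in for LoadedModel's memory accounting (names assumed)."""
    def __init__(self, total_bytes, loaded_bytes, device):
        self.total_bytes = total_bytes    # full weight size
        self.loaded_bytes = loaded_bytes  # portion already on `device`
        self.device = device

    def model_memory(self):
        return self.total_bytes

    def model_offloaded_memory(self):
        # Whatever is not yet resident still has to fit on the device.
        return self.total_bytes - self.loaded_bytes

    def model_memory_required(self, device):
        if device == self.device:
            # The old code returned 0 here, hiding the offloaded remainder.
            return self.model_offloaded_memory()
        return self.model_memory()

# A 6 GB model with 2 GB already resident still needs 4 GB on its device:
m = LoadedModelSketch(6 * 1024**3, 2 * 1024**3, "cuda:0")
assert m.model_memory_required("cuda:0") == 4 * 1024**3
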
@@ -308,6 +308,12 @@ class LoadedModel:
         load_weights = not self.weights_loaded

-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
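
This is the warm-path fix: when the model already has weights on the device (loaded_size() > 0), the old code tore the model down and re-patched it on every load, which is one of the slowdowns the commit title refers to; the new code just raises the model's VRAM budget, mapping a budget of 0 to the 1e32 sentinel meaning "no lowvram cap". A self-contained sketch of that logic follows; ToyModel and partially_load() are hypothetical stand-ins, and only loaded_size(), lowvram_model_memory and the 1e32 sentinel come from the hunk itself.

class ToyModel:
    """Hypothetical stand-in for a partially loadable model."""
    def __init__(self, size, loaded=0):
        self.size = size      # total weight bytes
        self.loaded = loaded  # bytes already on the GPU

    def loaded_size(self):
        return self.loaded

    def partially_load(self, budget):
        # Pull in as much of the remaining weights as the budget allows.
        self.loaded = min(self.size, self.loaded + budget)

def model_load(model, lowvram_model_memory=0):
    if model.loaded_size() > 0:
        # Warm path: weights already resident, so grow the VRAM budget
        # instead of re-patching the whole model from scratch.
        use_more_vram = lowvram_model_memory
        if use_more_vram == 0:
            use_more_vram = 1e32  # 0 means "no lowvram cap"
        model.partially_load(use_more_vram)
    else:
        pass  # cold path: the original patch/load sequence (elided)

m = ToyModel(size=100, loaded=30)
model_load(m)  # no cap requested, so the remaining weights get pulled in
assert m.loaded_size() == 100
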
@@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     total_memory_required = {}
     for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+    for loaded_model in models_already_loaded:
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for loaded_model in models_to_load:
         weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
         if weights_unloaded is not None:
             loaded_model.weights_loaded = not weights_unloaded

+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
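
Three coordinated changes here: the per-device totals are now always accumulated (previously they were skipped when unload_model_clones() returned False), models that are already loaded contribute their own requirement too (which, per the first hunk, is just their offloaded remainder), and free_memory() runs after the clone bookkeeping with its headroom factor cut from 1.3x to 1.1x, so cached models are evicted less aggressively between runs. A small worked example of the per-device accounting; the SimpleNamespace models, device strings, and byte counts are all made-up illustration data.

from types import SimpleNamespace

to_load = [SimpleNamespace(device="cuda:0", memory_required=4.0e9)]
already_loaded = [SimpleNamespace(device="cuda:0", memory_required=1.0e9)]

total_memory_required = {}
for m in to_load + already_loaded:
    total_memory_required[m.device] = total_memory_required.get(m.device, 0) + m.memory_required

extra_mem = 0
for device, needed in total_memory_required.items():
    if device != "cpu":
        # Headroom factor dropped from 1.3x to 1.1x; with these numbers
        # that is 5.5 GB to keep free instead of 6.5 GB.
        print(f"{device}: ensure {(needed * 1.1 + extra_mem) / 1e9:.1f} GB free")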

comfy/model_patcher.py

@@ -102,7 +102,7 @@ class ModelPatcher:
         self.size = size
         self.model = model
         if not hasattr(self.model, 'device'):
-            logging.info("Model doesn't have a device attribute.")
+            logging.debug("Model doesn't have a device attribute.")
             self.model.device = offload_device
         elif self.model.device is None:
             self.model.device = offload_device
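
The model_patcher.py change only demotes a routine message from INFO to DEBUG, so it no longer prints every time a ModelPatcher is built for a model without a device attribute. If you still want to see it while debugging device placement, raise the log level; a generic root-logger setup (not ComfyUI-specific configuration) is enough:

import logging

# Surface DEBUG-level messages such as the demoted one above:
logging.basicConfig(level=logging.DEBUG)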