Fix some performance issues with weight loading and unloading.

Lower peak memory usage when changing models.

Fix a case where model weights would be unnecessarily unloaded and reloaded.
comfyanonymous 2024-03-28 18:01:04 -04:00
parent 327ca1313d
commit 5d8898c056
2 changed files with 12 additions and 5 deletions
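
The caller-facing change is the new keep_clone_weights_loaded parameter on cleanup_models() in comfy.model_management (default False, which preserves the previous behavior); the prompt executor passes True so that weights referenced through model clones stay resident between prompts instead of being unloaded and reloaded. A minimal caller-side sketch, for illustration only and not part of the diff:

import comfy.model_management

# Between cached prompt executions: drop stale LoadedModel entries, but keep
# the weights of models that are still referenced through clones so they are
# not unloaded only to be reloaded on the next prompt.
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

# Default call (keep_clone_weights_loaded=False): prior behavior, entries whose
# model has no other references are dropped outright.
comfy.model_management.cleanup_models()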


@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True: #unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:
         x = current_loaded_models.pop(i)
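
For context on the refcount thresholds in the hunk above: sys.getrefcount() includes the temporary reference created by passing the object to the call itself, so a result of 2 means only one other reference exists (here, the one held by the LoadedModel entry), and real_model is checked against 3 because, per the inline comment, it is also reachable through the patched .model. A standalone sketch of that getrefcount behavior (plain CPython, not ComfyUI code):

import sys

obj = object()
print(sys.getrefcount(obj))   # typically 2 in CPython: the name `obj` plus the call argument

alias = obj
print(sys.getrefcount(obj))   # typically 3: `obj`, `alias`, and the call argument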


@@ -368,6 +368,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d
 
+        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
         self.add_message("execution_cached",
                     { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                     broadcast=False)