Fix some performance issues with weight loading and unloading.

Lower peak memory usage when changing models.

Fix a case where model weights would be unloaded and then immediately reloaded.
comfyanonymous 2024-03-28 18:01:04 -04:00
parent 327ca1313d
commit 5d8898c056
2 changed files with 12 additions and 5 deletions

comfy/model_management.py

@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:
@@ -408,7 +410,7 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
@@ -448,10 +450,14 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:

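The keep_clone_weights_loaded path relies on CPython reference counts to decide whether a cached entry can be dropped. Below is a minimal, self-contained sketch of why the thresholds are 2 and 3; the Patcher class and the weights/loaded/real_model/clone names are hypothetical stand-ins for illustration, not part of the codebase:

import sys

class Patcher:
    """Hypothetical stand-in for a ModelPatcher wrapping the real weight-holding module."""
    def __init__(self, real_model):
        self.model = real_model      # clones of a patcher share this object

weights = object()                   # stands in for the actual nn.Module
loaded = [Patcher(weights)]          # plays the role of current_loaded_models[i].model

# Only the list entry plus getrefcount()'s own temporary argument reference the
# patcher, so the count is 2; anything above 2 means a node or clone still holds it.
print(sys.getrefcount(loaded[0]))    # 2

real_model = loaded[0].model         # plays the role of LoadedModel.real_model
# References here: the `weights` name, loaded[0].model, `real_model`, and the
# temporary argument, i.e. 4. Inside cleanup_models only .real_model and the
# patcher's .model remain, which is why the threshold there is <= 3.
print(sys.getrefcount(real_model))   # 4

clone = Patcher(weights)             # a clone sharing the same weights raises the count
print(sys.getrefcount(real_model))   # 5 -> the weights should stay loaded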
execution.py

@@ -368,6 +368,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d
 
+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                           { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                           broadcast=False)
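On the executor side, the cleanup that runs while cached outputs are pruned now passes keep_clone_weights_loaded=True, so a model whose weights are still shared by a clone is not unloaded only to be reloaded by the next prompt. A rough usage sketch, assuming any default call elsewhere is left unchanged:

import comfy.model_management

# During prompt execution (mirrors the diff above): entries whose real_model is
# still referenced by a clone survive with their weights loaded.
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

# The default call keeps the previous behaviour and drops any entry whose patcher
# is no longer referenced, regardless of shared weights.
comfy.model_management.cleanup_models()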