From 5d8898c05668b5504f8ad5bc79779381d0af35b5 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 28 Mar 2024 18:01:04 -0400
Subject: [PATCH] Fix some performance issues with weight loading and
 unloading.

Lower peak memory usage when changing model.

Fix case where model weights would be unloaded and reloaded.
---
 comfy/model_management.py | 16 +++++++++++-----
 execution.py              |  1 +
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 715ca2ee..26216432 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:
         x = current_loaded_models.pop(i)
diff --git a/execution.py b/execution.py
index 1b8f606a..35bdb77a 100644
--- a/execution.py
+++ b/execution.py
@@ -368,6 +368,7 @@ class PromptExecutor:
                     d = self.outputs_ui.pop(x)
                     del d
 
+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                           { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False)
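
Reviewer note (not part of the patch): the "<= 2" and "<= 3" thresholds in
cleanup_models() rely on CPython's sys.getrefcount(), which always counts one
extra temporary reference created by the call itself. The sketch below is a
minimal standalone illustration of that heuristic under that assumption;
FakePatcher and `loaded` are hypothetical stand-ins for the patcher objects
and current_loaded_models, not ComfyUI code.

import sys

class FakePatcher:
    # Hypothetical stand-in for the patcher objects tracked by
    # current_loaded_models; real_model mimics LoadedModel.real_model.
    def __init__(self):
        self.real_model = object()

loaded = [FakePatcher()]  # stands in for current_loaded_models

# The list slot plus getrefcount()'s own temporary argument add up to 2,
# which is why cleanup_models() reads "<= 2" as "nothing outside the loader
# still references this model".
print(sys.getrefcount(loaded[0]))             # 2 under CPython

# The attribute plus the temporary give 2 here; in ComfyUI the underlying
# model is also held by the patcher itself, which is what the "<= 3" check
# and the "#references from .real_model + the .model" comment account for.
print(sys.getrefcount(loaded[0].real_model))  # 2 under CPython

With keep_clone_weights_loaded=True, as passed from execution.py, an entry
whose patcher is otherwise unreferenced is kept as long as its real_model is
still shared, which appears to be how the patch avoids the case where weights
were unloaded and immediately reloaded between prompts.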