Fix some performance issues with weight loading and unloading.
Lower peak memory usage when changing models. Fix a case where model weights would be unloaded and reloaded.
parent 327ca1313d
commit 5d8898c056
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
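Reading the two hunks above together: unload_model_clones() now returns True at least in the no-clones-to-unload case, and load_models_gpu() only adds a model's size to the per-device memory budget when that call returns True, so a model that is already resident through a clone is not budgeted again. As far as these hunks show, that is where the lower peak memory usage from the commit message comes from. A minimal, self-contained sketch of the gating; ToyLoadedModel, toy_unload_model_clones and toy_memory_budget are illustrative names (not repository code) and the clone test is reduced to object identity:

from dataclasses import dataclass

@dataclass
class ToyLoadedModel:            # illustrative stand-in, not repository code
    device: str
    size: int

def toy_unload_model_clones(model, current_loaded):
    # Reduced clone test: an identical object already in the list counts as a
    # loaded clone. The real function inspects ModelPatcher clones and weights.
    clones = [lm for lm in current_loaded if lm is model]
    if len(clones) == 0:
        return True   # nothing resident -> the caller must budget memory for it
    return None       # a clone is resident; the real function's remaining logic is omitted

def toy_memory_budget(models_to_load, current_loaded):
    total = {}
    for lm in models_to_load:
        # Only count models for which the unload call reported True.
        if toy_unload_model_clones(lm, current_loaded) == True:
            total[lm.device] = total.get(lm.device, 0) + lm.size
    return total

# A model already resident via a clone is not added to the budget:
a = ToyLoadedModel("cuda:0", 4000)
b = ToyLoadedModel("cuda:0", 2000)
print(toy_memory_budget([a, b], current_loaded=[a]))   # {'cuda:0': 2000}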
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:
         x = current_loaded_models.pop(i)
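The hunk above is where the new keep_clone_weights_loaded flag does its work. A simplified, standalone version of that logic follows; ToyLoaded and the module-level current_loaded list are illustrative stand-ins, not repository code, and the refcount thresholds follow the comments in the diff:

import sys

class ToyLoaded:                       # illustrative stand-in for LoadedModel
    def __init__(self, model, real_model):
        self.model = model             # patcher-level object handed to callers
        self.real_model = real_model   # underlying model that holds the weights

current_loaded = []

def cleanup_models(keep_clone_weights_loaded=False):
    to_delete = []
    for i in range(len(current_loaded)):
        # <= 2: only this list entry and getrefcount()'s own temporary still
        # reference .model, so nothing else is using the patcher any more.
        if sys.getrefcount(current_loaded[i].model) <= 2:
            if not keep_clone_weights_loaded:
                to_delete = [i] + to_delete
            # With the flag set, only drop the entry once nothing else holds the
            # real model either; <= 3 covers the reference from .real_model, the
            # one from .model, and getrefcount()'s temporary (per the diff comment).
            elif sys.getrefcount(current_loaded[i].real_model) <= 3:
                to_delete = [i] + to_delete
    # to_delete is built back-to-front, so popping never shifts a pending index.
    for i in to_delete:
        current_loaded.pop(i)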
--- a/execution.py
+++ b/execution.py
@@ -368,6 +368,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d
 
+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                           { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                           broadcast=False)
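The execution.py hunk above is the caller-side half of the change: once stale cached outputs have been dropped, models are cleaned up with keep_clone_weights_loaded=True, so clones whose weights are still loaded are not torn down only to be reloaded by the next execution. A minimal usage sketch (the wrapper function name is illustrative, not from the repository):

import comfy.model_management

def cleanup_between_prompts():
    # Drop LoadedModel entries that nothing references any more, but keep weights
    # resident for still-referenced clones so the next prompt skips an unload/reload.
    comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

The self.real_model = None assignments in the first two hunks keep that attribute defined for entries that were never loaded or have already been unloaded, so the refcount check in cleanup_models() always has something to inspect.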