diff --git a/comfy/model_management.py b/comfy/model_management.py index 74958908..11c97f29 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -319,16 +319,14 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(loaded_model, unload_weights_only=True): - model = loaded_model.model - +def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload if len(to_unload) == 0: - return + return None same_weights = 0 for i in to_unload: @@ -340,14 +338,15 @@ def unload_model_clones(loaded_model, unload_weights_only=True): else: unload_weight = True - if unload_weights_only and unload_weight == False: - return + if not force_unload: + if unload_weights_only and unload_weight == False: + return None for i in to_unload: logging.debug("unload clone {} {}".format(i, unload_weight)) current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) - loaded_model.weights_loaded = not unload_weight + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False @@ -402,7 +401,7 @@ def load_models_gpu(models, memory_required=0): total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: @@ -410,7 +409,9 @@ def load_models_gpu(models, memory_required=0): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) for loaded_model in models_to_load: - unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded for loaded_model in models_to_load: model = loaded_model.model