From c14ac98fedd0176686d285d384abec5e4c0140c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Aug 2024 03:22:39 -0400 Subject: [PATCH] Unload models and load them back in lowvram mode no free vram. --- comfy/model_management.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3d9ed525..cdbcd0be 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = [] can_unload = [] + unloaded_models = [] for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]): unloaded_model.append(i) for i in sorted(unloaded_model, reverse=True): - current_loaded_models.pop(i) + unloaded_models.append(current_loaded_models.pop(i)) if len(unloaded_model) > 0: soft_empty_cache() @@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]): mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) if mem_free_torch > mem_free_total * 0.25: soft_empty_cache() + return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): global vram_state @@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu for d in devs: if d != torch.device("cpu"): free_memory(extra_mem, d, models_already_loaded) - return + free_mem = get_free_memory(d) + if free_mem < minimum_memory_required: + logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. + models_to_load = free_memory(minimum_memory_required, d) + logging.info("{} models unloaded.".format(len(models_to_load))) + if len(models_to_load) == 0: + return logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")