Unload models and load them back in lowvram mode no free vram.

This commit is contained in:
comfyanonymous 2024-08-06 03:22:39 -04:00
parent 2894511893
commit c14ac98fed
1 changed files with 10 additions and 2 deletions

View File

@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = []
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model.append(i) unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
current_loaded_models.pop(i) unloaded_models.append(current_loaded_models.pop(i))
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25: if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache() soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
global vram_state global vram_state
@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for d in devs: for d in devs:
if d != torch.device("cpu"): if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
return free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
if len(models_to_load) == 0:
return
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")