Unload models and load them back in lowvram mode when there is no free vram.
parent 2894511893
commit c14ac98fed
@@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = []
     can_unload = []
+    unloaded_models = []
 
     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
@@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        current_loaded_models.pop(i)
+        unloaded_models.append(current_loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
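
The changed line above also shows why the indices in unloaded_model are popped in descending order: removing the highest index first leaves the remaining indices valid. A small standalone illustration of that pattern (the names here are made up, not ComfyUI's):

# Pop several indices from a list while collecting the removed items.
# Descending order keeps the not-yet-popped indices valid; ascending
# order would shift later elements and pop the wrong ones.
items = ["unet", "clip", "vae", "controlnet"]
to_remove = [0, 2]  # indices collected earlier, in any order

removed = []
for i in sorted(to_remove, reverse=True):
    removed.append(items.pop(i))

print(removed)  # ['vae', 'unet']
print(items)    # ['clip', 'controlnet']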
@@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()
+    return unloaded_models
 
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
     global vram_state
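
With return unloaded_models, free_memory is no longer fire-and-forget: the caller learns exactly which models were evicted and can queue them for reloading. A rough sketch of that contract, using simplified stand-ins (plain tuples and a fake memory counter) rather than ComfyUI's real LoadedModel bookkeeping:

# Sketch of the new free_memory contract: evict loaded models (oldest
# first) until the request fits, and hand the evicted entries back to
# the caller so they can be reloaded later. The list entries and the
# memory accounting are simplified stand-ins, not ComfyUI's structures.
current_loaded_models = [("unet", 6000), ("clip", 1500), ("vae", 300)]  # newest first
free_mem = 1000  # pretend free-VRAM counter, in MB

def free_memory(memory_required, device="cuda:0", keep_loaded=[]):
    global free_mem
    unloaded_models = []
    for i in range(len(current_loaded_models) - 1, -1, -1):  # oldest first
        if free_mem >= memory_required:
            break
        name, size = current_loaded_models[i]
        if (name, size) in keep_loaded:
            continue
        unloaded_models.append(current_loaded_models.pop(i))
        free_mem += size  # pretend unloading reclaims the model's VRAM
    return unloaded_models

evicted = free_memory(2000)
print(evicted)                # [('vae', 300), ('clip', 1500)]
print(current_loaded_models)  # [('unet', 6000)]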
@@ -421,6 +423,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
         for d in devs:
             if d != torch.device("cpu"):
                 free_memory(extra_mem, d, models_already_loaded)
-        return
+                free_mem = get_free_memory(d)
+                if free_mem < minimum_memory_required:
+                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
+                    models_to_load = free_memory(minimum_memory_required, d)
+                    logging.info("{} models unloaded.".format(len(models_to_load)))
+        if len(models_to_load) == 0:
+            return
 
     logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
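
Taken together, the two halves change what happens when every requested model is already resident but the device still lacks minimum_memory_required: instead of giving up, load_models_gpu evicts what it can, treats the evicted models as fresh load requests, and falls through to the loading code below, which brings them back with lowvram weights. A condensed, hypothetical sketch of that control flow (the stub helpers only mimic the real ones' signatures):

import logging

logging.basicConfig(level=logging.INFO)

# Hypothetical stubs mimicking the helpers used in the diff above.
def get_free_memory(d):
    return 1000  # pretend the device is nearly full (MB)

def free_memory(required, d, keep_loaded=[]):
    return ["unet"]  # pretend one model had to be evicted

def already_loaded_path(devs, minimum_memory_required):
    """Condensed version of the early-exit branch patched above."""
    models_to_load = []
    for d in devs:
        if d != "cpu":
            free_memory(0, d)  # the original extra_mem pass
            free_mem = get_free_memory(d)
            if free_mem < minimum_memory_required:
                # Still short on VRAM: evict and re-queue the evictees.
                logging.info("Unloading models for lowram load.")
                models_to_load = free_memory(minimum_memory_required, d)
                logging.info("{} models unloaded.".format(len(models_to_load)))
    if len(models_to_load) == 0:
        return  # nothing evicted: same early exit as before the patch
    logging.info("Loading {} new model(s) (lowvram likely).".format(len(models_to_load)))
    # ...the real function falls through to its normal loading loop here.

already_loaded_path(["cuda:0"], minimum_memory_required=4000)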