Add: --highvram for when you want models to stay in VRAM.
parent 09f1d76ed8
commit 2326ff1263
@@ -3,6 +3,7 @@ CPU = 0
 NO_VRAM = 1
 LOW_VRAM = 2
 NORMAL_VRAM = 3
+HIGH_VRAM = 4
 
 accelerate_enabled = False
 vram_state = NORMAL_VRAM
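For orientation, here is a minimal annotated sketch of the memory-strategy ladder this hunk extends; the behavior notes in the comments summarize the rest of the diff, not code from this hunk.

# The five memory strategies, ordered from least to most GPU residency.
CPU = 0          # run on the CPU, never touch the GPU
NO_VRAM = 1      # offload as much as possible (via accelerate)
LOW_VRAM = 2     # partial offload (via accelerate)
NORMAL_VRAM = 3  # load to GPU, move back to CPU on unload
HIGH_VRAM = 4    # new: load to GPU and keep it resident

accelerate_enabled = False
vram_state = NORMAL_VRAM  # default; overridden by the CLI flags below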
@@ -27,10 +28,11 @@ if "--lowvram" in sys.argv:
     set_vram_to = LOW_VRAM
 if "--novram" in sys.argv:
     set_vram_to = NO_VRAM
+if "--highvram" in sys.argv:
+    vram_state = HIGH_VRAM
 
 
 
-if set_vram_to != NORMAL_VRAM:
+if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
     try:
         import accelerate
         accelerate_enabled = True
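Pieced together, the flag handling now looks roughly like the sketch below. Note the asymmetry: --highvram writes vram_state directly, while the offload flags go through set_vram_to, so the rewritten guard pulls in accelerate only for the two offload modes. This is a runnable standalone sketch; the except branch is an assumption, since the hunk cuts off after accelerate_enabled = True.

import sys

NO_VRAM, LOW_VRAM, NORMAL_VRAM, HIGH_VRAM = 1, 2, 3, 4
accelerate_enabled = False
vram_state = NORMAL_VRAM

set_vram_to = NORMAL_VRAM
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
    vram_state = HIGH_VRAM  # set directly; bypasses set_vram_to entirely

# Only the offload modes need accelerate, so --highvram carries no
# extra dependency.
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
    except ImportError as e:  # assumed handler; not shown in the hunk
        print("accelerate not available:", e)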
@@ -44,7 +46,7 @@ if set_vram_to != NORMAL_VRAM:
     total_vram_available_mb = int(max(256, total_vram_available_mb))
 
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
 
 
 current_loaded_model = None
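This one-line change is load-bearing: indexing the old four-entry label list with vram_state == HIGH_VRAM (4) would crash at startup. A quick illustration:

labels = ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"]
vram_state = 4  # HIGH_VRAM
# labels[vram_state]  -> IndexError: list index out of range
labels.append("HIGH VRAM")
print("Set vram state to:", labels[vram_state])  # Set vram state to: HIGH VRAM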
@@ -57,18 +59,24 @@ def unload_model():
     global current_loaded_model
     global model_accelerated
     global current_gpu_controlnets
+    global vram_state
+
     if current_loaded_model is not None:
         if model_accelerated:
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False
 
-        current_loaded_model.model.cpu()
+        #never unload models from GPU on high vram
+        if vram_state != HIGH_VRAM:
+            current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
-    if len(current_gpu_controlnets) > 0:
-        for n in current_gpu_controlnets:
-            n.cpu()
-        current_gpu_controlnets = []
+
+    if vram_state != HIGH_VRAM:
+        if len(current_gpu_controlnets) > 0:
+            for n in current_gpu_controlnets:
+                n.cpu()
+            current_gpu_controlnets = []
 
 
 def load_model_gpu(model):
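The unload path is where --highvram actually pays off: weights and controlnets stay on the GPU and only the patches are rolled back. Below is a condensed, runnable sketch with stand-in objects; DummyModel and DummyPatcher are hypothetical and mirror only the interface the diff shows.

HIGH_VRAM, NORMAL_VRAM = 4, 3

class DummyModel:
    def cpu(self):
        print("weights moved to system RAM")
        return self

class DummyPatcher:
    def __init__(self):
        self.model = DummyModel()
    def unpatch_model(self):
        print("patches removed")

def unload_model(loaded, vram_state):
    if loaded is not None:
        # never unload models from GPU on high vram
        if vram_state != HIGH_VRAM:
            loaded.model.cpu()
        loaded.unpatch_model()
    return None

unload_model(DummyPatcher(), NORMAL_VRAM)  # weights moved to system RAM, patches removed
unload_model(DummyPatcher(), HIGH_VRAM)    # patches removed only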
@@ -87,7 +95,7 @@ def load_model_gpu(model):
     current_loaded_model = model
     if vram_state == CPU:
         pass
-    elif vram_state == NORMAL_VRAM:
+    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
         real_model.cuda()
     else:
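On the load side, HIGH_VRAM simply piggybacks on the NORMAL_VRAM branch: the whole model moves to the GPU with .cuda(), and accelerate stays out of the picture. The difference only materializes later, when unload_model() skips the move back to the CPU. A condensed, runnable sketch of the dispatch; FakeModel and place_model are hypothetical stand-ins.

CPU, NO_VRAM, LOW_VRAM, NORMAL_VRAM, HIGH_VRAM = 0, 1, 2, 3, 4

class FakeModel:
    def cuda(self):
        print("entire model moved to the GPU")
        return self

def place_model(real_model, vram_state):
    if vram_state == CPU:
        pass  # inference stays on the CPU
    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
        # same load path for both; HIGH_VRAM differs only at unload time
        real_model.cuda()
    else:
        pass  # LOW_VRAM / NO_VRAM would dispatch through accelerate
    return real_model

place_model(FakeModel(), HIGH_VRAM)  # entire model moved to the GPU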