Add some low vram modes: --lowvram and --novram
This commit is contained in:
parent
a84cd0d1ad
commit
534736b924
|
@ -66,7 +66,7 @@ class DiscreteSchedule(nn.Module):
|
|||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
if quantize:
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
|
|
|
@ -1,11 +1,48 @@
|
|||
|
||||
CPU = 0
|
||||
NO_VRAM = 1
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
|
||||
accelerate_enabled = False
|
||||
vram_state = NORMAL_VRAM
|
||||
|
||||
import sys
|
||||
|
||||
set_vram_to = NORMAL_VRAM
|
||||
if "--lowvram" in sys.argv:
|
||||
set_vram_to = LOW_VRAM
|
||||
if "--novram" in sys.argv:
|
||||
set_vram_to = NO_VRAM
|
||||
|
||||
if set_vram_to != NORMAL_VRAM:
|
||||
try:
|
||||
import accelerate
|
||||
accelerate_enabled = True
|
||||
vram_state = set_vram_to
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
||||
|
||||
|
||||
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
|
||||
|
||||
model_accelerated = False
|
||||
|
||||
|
||||
def unload_model():
|
||||
global current_loaded_model
|
||||
global model_accelerated
|
||||
if current_loaded_model is not None:
|
||||
if model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||
model_accelerated = False
|
||||
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
|
@ -13,6 +50,9 @@ def unload_model():
|
|||
|
||||
def load_model_gpu(model):
|
||||
global current_loaded_model
|
||||
global vram_state
|
||||
global model_accelerated
|
||||
|
||||
if model is current_loaded_model:
|
||||
return
|
||||
unload_model()
|
||||
|
@ -22,5 +62,16 @@ def load_model_gpu(model):
|
|||
model.unpatch_model()
|
||||
raise e
|
||||
current_loaded_model = model
|
||||
if vram_state == CPU:
|
||||
pass
|
||||
elif vram_state == NORMAL_VRAM:
|
||||
model_accelerated = False
|
||||
real_model.cuda()
|
||||
else:
|
||||
if vram_state == NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_state == LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "1GiB", "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
|
3
main.py
3
main.py
|
@ -14,6 +14,9 @@ if __name__ == "__main__":
|
|||
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
|
||||
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
|
||||
print()
|
||||
print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
|
||||
print("\t--novram\t\t\tWhen lowvram isn't enough.")
|
||||
print()
|
||||
exit()
|
||||
|
||||
if '--dont-upcast-attention' in sys.argv:
|
||||
|
|
|
@ -8,3 +8,5 @@ transformers
|
|||
safetensors
|
||||
pytorch_lightning
|
||||
|
||||
accelerate
|
||||
|
||||
|
|
Loading…
Reference in New Issue