Fix issue with lowvram mode breaking model saving.
This commit is contained in:
parent
4f63ee99f1
commit
e1489ad257
|
@ -285,7 +285,7 @@ class LoadedModel:
|
|||
else:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0):
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
|
@ -295,7 +295,7 @@ class LoadedModel:
|
|||
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
|
@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
def load_models_gpu(models, memory_required=0):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
|
@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
|
|||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
|
|
|
@ -272,7 +272,7 @@ class ModelPatcher:
|
|||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||
|
@ -296,9 +296,15 @@ class ModelPatcher:
|
|||
|
||||
if lowvram_weight:
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
|
|
|
@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models)
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
|
|
|
@ -262,7 +262,7 @@ class CLIPSave:
|
|||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comfy.model_management.load_models_gpu([clip.load_model()])
|
||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", ""]:
|
||||
|
|
Loading…
Reference in New Issue