Fix issue with lowvram mode breaking model saving.

This commit is contained in:
comfyanonymous 2024-05-11 21:46:05 -04:00
parent 4f63ee99f1
commit e1489ad257
4 changed files with 15 additions and 9 deletions

View File

@ -285,7 +285,7 @@ class LoadedModel:
else:
return self.model_memory()
def model_load(self, lowvram_model_memory=0):
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
@ -295,7 +295,7 @@ class LoadedModel:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
inference_memory = minimum_inference_memory()
@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return

View File

@ -272,7 +272,7 @@ class ModelPatcher:
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
@ -296,9 +296,15 @@ class ModelPatcher:
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True

View File

@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:

View File

@ -262,7 +262,7 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]: