Fix cached_hook_patches not respecting target_device/memory_counter results

This commit is contained in:
Jedrzej Kosinski 2024-11-17 12:52:40 -06:00
parent a20be20ac7
commit f465004a58
1 changed file with 12 additions and 9 deletions

View File

@ -981,7 +981,7 @@ class ModelPatcher:
with self.use_ejected():
self.unpatch_hooks()
if hooks is not None:
model_sd = self.model_state_dict()
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
@ -991,7 +991,7 @@ class ModelPatcher:
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
for key in cached_weights:
if key not in model_sd:
if key not in model_sd_keys:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
@ -1001,7 +1001,7 @@ class ModelPatcher:
if len(relevant_patches) > 0:
original_weights = self.get_key_patches()
for key in relevant_patches:
if key not in model_sd:
if key not in model_sd_keys:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
@ -1018,9 +1018,9 @@ class ModelPatcher:
target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
if self.weight_inplace_update:
comfy.utils.copy_to_param(self.model, key, cached_weights[key])
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
else:
comfy.utils.set_attr_param(self.model, key, cached_weights[key])
comfy.utils.set_attr_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
def clear_cached_hook_weights(self):
self.cached_hook_patches.clear()
@ -1037,19 +1037,22 @@ class ModelPatcher:
if used:
target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
# TODO: properly handle lowvram situations for cached hook patches
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights).to(weight.dtype)
out_weight = comfy.lora.calculate_weight(combined_patches[key],
comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
key, original_weights=original_weights).to(weight.dtype)
del original_weights[key]
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
self.cached_hook_patches.setdefault(hooks, {})
self.cached_hook_patches[hooks][key] = out_weight
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
if self.weight_inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
del weight
del out_weight
def unpatch_hooks(self) -> None:
with self.use_ejected():