Fix cached_hook_patches not respecting target_device/memory_counter results
parent a20be20ac7
commit f465004a58
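The core of the fix: `cached_hook_patches` previously stored bare tensors, so a cached weight was applied from wherever it happened to live, ignoring the `target_device` that the `MemoryCounter` budget had approved. Entries now mirror the `(tensor, original_device)` shape already used by `hook_backup`. A minimal sketch of the new entry shape with illustrative stand-in values (the real `target_device` comes from the `MemoryCounter` logic in the hunks below):

import torch

# Illustrative stand-ins; in the diff, target_device is chosen by MemoryCounter
# and original_device is weight.device before patching.
out_weight = torch.zeros(4)
target_device = torch.device("cpu")
original_device = torch.device("cpu")

# New cache-entry shape: the tensor parked on the approved device, paired with
# the device it must return to when the cached patch is applied.
cached_entry = (out_weight.to(device=target_device), original_device)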
@@ -981,7 +981,7 @@ class ModelPatcher:
         with self.use_ejected():
             self.unpatch_hooks()
             if hooks is not None:
-                model_sd = self.model_state_dict()
+                model_sd_keys = list(self.model_state_dict().keys())
                 memory_counter = None
                 if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                     # TODO: minimum_counter should have a minimum that conforms to loaded model requirements
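The loops in the next two hunks only ever test key membership against the state dict; the tensors themselves are never read there, so a plain list of names suffices. A self-contained sketch of the keys-only idea (standard `torch.nn.Module` API; `model_state_dict()` in the diff is ComfyUI's wrapper around it):

import torch

def keys_only(model: torch.nn.Module) -> list[str]:
    # Membership checks in the patch loops need names, not tensors.
    return list(model.state_dict().keys())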
@@ -991,7 +991,7 @@ class ModelPatcher:
                 cached_weights = self.cached_hook_patches.get(hooks, None)
                 if cached_weights is not None:
                     for key in cached_weights:
-                        if key not in model_sd:
+                        if key not in model_sd_keys:
                             print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
                             continue
                         self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
@@ -1001,7 +1001,7 @@ class ModelPatcher:
                 if len(relevant_patches) > 0:
                     original_weights = self.get_key_patches()
                     for key in relevant_patches:
-                        if key not in model_sd:
+                        if key not in model_sd_keys:
                             print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
                             continue
                         self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
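Design note, not part of the commit: both warning paths above run `key not in model_sd_keys` inside a per-key loop, and on a list each check is a linear scan. A set would make each membership test O(1) while preserving the keys-only intent; a sketch of that variant:

import torch

def keys_for_membership(model: torch.nn.Module) -> set[str]:
    # Same information as list(model.state_dict().keys()), but each
    # `key in ...` test is O(1) instead of O(n).
    return set(model.state_dict().keys())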
@@ -1018,9 +1018,9 @@ class ModelPatcher:
                     target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
         if self.weight_inplace_update:
-            comfy.utils.copy_to_param(self.model, key, cached_weights[key])
+            comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
         else:
-            comfy.utils.set_attr_param(self.model, key, cached_weights[key])
+            comfy.utils.set_attr_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))

     def clear_cached_hook_weights(self):
         self.cached_hook_patches.clear()
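With the tuple layout, applying a cached weight first moves the stored tensor back to the device recorded at cache time (`cached_weights[key][1]`) rather than copying in a bare tensor from whatever device it was left on. A simplified, hypothetical rendering of the restore path (`copy_to_param` and `set_attr_param` are ComfyUI helpers, approximated here with plain parameter assignment):

import torch

def apply_cached_entry(param: torch.nn.Parameter,
                       entry: tuple[torch.Tensor, torch.device],
                       inplace_update: bool) -> None:
    cached, original_device = entry
    restored = cached.to(device=original_device)  # return to the model's device
    if inplace_update:
        param.data.copy_(restored)   # roughly what copy_to_param does
    else:
        param.data = restored        # roughly what set_attr_param does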
@@ -1037,19 +1037,22 @@ class ModelPatcher:
             if used:
                 target_device = weight.device
         self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)

         # TODO: properly handle lowvram situations for cached hook patches
-        out_weight = comfy.lora.calculate_weight(combined_patches[key],
-                                                 comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
-                                                 key, original_weights=original_weights).to(weight.dtype)
+        temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
+        out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights).to(weight.dtype)
         del original_weights[key]
         out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
             # TODO: disable caching if not enough system RAM to do so
             self.cached_hook_patches.setdefault(hooks, {})
-            self.cached_hook_patches[hooks][key] = out_weight
+            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
         if self.weight_inplace_update:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else:
             comfy.utils.set_attr_param(self.model, key, out_weight)
         del weight
         del out_weight

     def unpatch_hooks(self) -> None:
         with self.use_ejected():
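In MaxSpeed mode, the `MemoryCounter` decides per weight whether the cached copy may stay on the fast device or must be parked on the offload device; the old code cached `out_weight` wherever it happened to be, bypassing that decision. A toy illustration of the budget gate (`ToyMemoryCounter` is hypothetical; the real counter lives in ComfyUI's model management code):

import torch

class ToyMemoryCounter:
    def __init__(self, budget_bytes: int):
        self.remaining = budget_bytes

    def use(self, t: torch.Tensor) -> bool:
        size = t.numel() * t.element_size()
        if size <= self.remaining:   # fits: weight may stay on the fast device
            self.remaining -= size
            return True
        return False                 # over budget: fall back to the offload device

counter = ToyMemoryCounter(budget_bytes=64 * 1024 * 1024)
weight = torch.zeros(1024, 1024)     # 4 MiB in float32
offload_device = torch.device("cpu")
target_device = weight.device if counter.use(weight) else offload_device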