Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches
This commit is contained in:
parent
f465004a58
commit
e844695292
|
@ -990,6 +990,7 @@ class ModelPatcher:
|
|||
# if have cached weights for hooks, use it
|
||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||
if cached_weights is not None:
|
||||
with torch.no_grad():
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
||||
|
@ -1000,6 +1001,7 @@ class ModelPatcher:
|
|||
original_weights = None
|
||||
if len(relevant_patches) > 0:
|
||||
original_weights = self.get_key_patches()
|
||||
with torch.no_grad():
|
||||
for key in relevant_patches:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
||||
|
@ -1016,11 +1018,8 @@ class ModelPatcher:
|
|||
used = memory_counter.use(weight)
|
||||
if used:
|
||||
target_device = weight.device
|
||||
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
|
||||
if self.weight_inplace_update:
|
||||
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
||||
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
|
||||
|
||||
def clear_cached_hook_weights(self):
|
||||
self.cached_hook_patches.clear()
|
||||
|
@ -1036,7 +1035,7 @@ class ModelPatcher:
|
|||
used = memory_counter.use(weight)
|
||||
if used:
|
||||
target_device = weight.device
|
||||
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
|
||||
self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
|
||||
# TODO: properly handle lowvram situations for cached hook patches
|
||||
out_weight = comfy.lora.calculate_weight(combined_patches[key],
|
||||
comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
|
||||
|
@ -1045,27 +1044,25 @@ class ModelPatcher:
|
|||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||
# TODO: disable caching if not enough system RAM to do so
|
||||
target_device = self.offload_device
|
||||
used = memory_counter.use(weight)
|
||||
if used:
|
||||
target_device = weight.device
|
||||
self.cached_hook_patches.setdefault(hooks, {})
|
||||
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
|
||||
if self.weight_inplace_update:
|
||||
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
del weight
|
||||
del out_weight
|
||||
|
||||
@torch.no_grad()
|
||||
def unpatch_hooks(self) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
self.current_hooks = None
|
||||
return
|
||||
keys = list(self.hook_backup.keys())
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||
|
||||
self.hook_backup.clear()
|
||||
self.current_hooks = None
|
||||
|
|
Loading…
Reference in New Issue