Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches
commit e844695292
parent f465004a58
@@ -990,6 +990,7 @@ class ModelPatcher:
             # if have cached weights for hooks, use it
             cached_weights = self.cached_hook_patches.get(hooks, None)
             if cached_weights is not None:
+                with torch.no_grad():
                     for key in cached_weights:
                         if key not in model_sd_keys:
                             print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
@@ -1000,6 +1001,7 @@ class ModelPatcher:
             original_weights = None
             if len(relevant_patches) > 0:
                 original_weights = self.get_key_patches()
+            with torch.no_grad():
                 for key in relevant_patches:
                     if key not in model_sd_keys:
                         print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
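
Both hunks above wrap the weight-patching loops in torch.no_grad() (and unpatch_hooks gains a @torch.no_grad() decorator in the last hunk below). This is not optional tidiness: an in-place write to a leaf tensor that requires grad raises a RuntimeError outside no_grad, and even writes that succeed would otherwise be tracked by autograd. A minimal standalone sketch of the failure mode (plain PyTorch, not ComfyUI code):

import torch

linear = torch.nn.Linear(4, 4)
patched = torch.zeros_like(linear.weight)

# Outside no_grad: in-place ops on a leaf tensor that requires grad
# are rejected by autograd.
try:
    linear.weight.copy_(patched)
except RuntimeError as err:
    print(f"without no_grad: {err}")

# Inside no_grad the same write succeeds, and nothing is recorded
# in the autograd graph.
with torch.no_grad():
    linear.weight.copy_(patched)
print(linear.weight.abs().sum().item())  # 0.0
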
@@ -1016,11 +1018,8 @@ class ModelPatcher:
                 used = memory_counter.use(weight)
                 if used:
                     target_device = weight.device
-            self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
-        if self.weight_inplace_update:
-            comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
-        else:
-            comfy.utils.set_attr_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
+            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
+        comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
 
     def clear_cached_hook_weights(self):
         self.cached_hook_patches.clear()
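
This hunk is the heart of the first half of the commit message. Two things change: the backup is now always taken with copy=True, and the if/else between copy_to_param and set_attr_param collapses to copy_to_param alone, so hook patches always copy data into the existing parameter instead of rebinding the attribute to a new one. The copy=True part matters because Tensor.to() with copy=False is allowed to return the original tensor object when no device or dtype conversion is needed; under the old copy=self.weight_inplace_update, hook_backup could end up holding an alias of the live weight, which the very next in-place patch would clobber. A standalone sketch of that hazard (plain PyTorch, not ComfyUI code):

import torch

weight = torch.ones(3)

# copy=False (the default): .to() is a no-op when device and dtype
# already match, so "backup" is the same object as weight.
backup = weight.to(device=weight.device)
assert backup is weight
weight.mul_(0)  # in-place patch...
print(backup)   # tensor([0., 0., 0.]) -- the backup was clobbered

# copy=True always materializes an independent tensor.
weight = torch.ones(3)
backup = weight.to(device=weight.device, copy=True)
weight.mul_(0)
print(backup)   # tensor([1., 1., 1.]) -- the backup survives
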
@@ -1036,7 +1035,7 @@ class ModelPatcher:
                 used = memory_counter.use(weight)
                 if used:
                     target_device = weight.device
-            self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
+            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
         # TODO: properly handle lowvram situations for cached hook patches
         out_weight = comfy.lora.calculate_weight(combined_patches[key],
                                                  comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
@@ -1045,27 +1044,25 @@ class ModelPatcher:
             out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
             # TODO: disable caching if not enough system RAM to do so
+            target_device = self.offload_device
+            used = memory_counter.use(weight)
+            if used:
+                target_device = weight.device
             self.cached_hook_patches.setdefault(hooks, {})
-            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
-        if self.weight_inplace_update:
-            comfy.utils.copy_to_param(self.model, key, out_weight)
-        else:
-            comfy.utils.set_attr_param(self.model, key, out_weight)
+            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
+        comfy.utils.copy_to_param(self.model, key, out_weight)
         del weight
         del out_weight
 
+    @torch.no_grad()
     def unpatch_hooks(self) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
                 return
             keys = list(self.hook_backup.keys())
-            if self.weight_inplace_update:
-                for k in keys:
-                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
-            else:
-                for k in keys:
-                    comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+            for k in keys:
+                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
 
             self.hook_backup.clear()
             self.current_hooks = None
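
The four added lines under the MaxSpeed branch are the "additional memory_counter check" from the commit message: before a hook-patched weight is cached, the counter is asked whether budget remains, and the cache entry stays on the weight's device only if use(weight) succeeds; otherwise it falls back to self.offload_device. Note also that the cache entry itself can safely use copy=False here: out_weight is freshly computed, and copy_to_param copies its data into the parameter rather than storing the tensor object, so the cached tensor never aliases a live model weight. MemoryCounter's internals are not part of this diff; the sketch below is a hypothetical, simplified stand-in that mirrors only the use() behavior the diff relies on:

import torch

class MemoryCounter:
    # Hypothetical, simplified budget tracker -- not ComfyUI's
    # actual implementation.
    def __init__(self, initial: int, minimum: int = 0):
        self.value = initial
        self.minimum = minimum

    def use(self, weight: torch.Tensor) -> bool:
        # Charge the tensor's byte size against the budget; refuse
        # once the remainder would drop below the floor.
        size = weight.nelement() * weight.element_size()
        if self.value - size < self.minimum:
            return False
        self.value -= size
        return True

# The gating pattern from the diff: keep the cached copy on the fast
# device while budget remains, otherwise offload it.
counter = MemoryCounter(initial=8 * 1024**2)  # pretend 8 MiB headroom
offload_device = torch.device("cpu")

weight = torch.zeros(1024, 1024)  # 4 MiB of fp32
target_device = offload_device
if counter.use(weight):
    target_device = weight.device
cached = weight.to(device=target_device, copy=False)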