Fixed an issue where weights from hooks were set on the model instead of being copied into it; added an additional memory_counter check when caching hook patches

This commit is contained in:
Jedrzej Kosinski 2024-11-18 07:25:53 -06:00
parent f465004a58
commit e844695292
1 changed file with 25 additions and 28 deletions

View File

@ -990,6 +990,7 @@ class ModelPatcher:
# if have cached weights for hooks, use it # if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None) cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None: if cached_weights is not None:
with torch.no_grad():
for key in cached_weights: for key in cached_weights:
if key not in model_sd_keys: if key not in model_sd_keys:
print(f"WARNING cached hook could not patch. key does not exist in model: {key}") print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
@ -1000,6 +1001,7 @@ class ModelPatcher:
original_weights = None original_weights = None
if len(relevant_patches) > 0: if len(relevant_patches) > 0:
original_weights = self.get_key_patches() original_weights = self.get_key_patches()
with torch.no_grad():
for key in relevant_patches: for key in relevant_patches:
if key not in model_sd_keys: if key not in model_sd_keys:
print(f"WARNING cached hook would not patch. key does not exist in model: {key}") print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
@ -1016,11 +1018,8 @@ class ModelPatcher:
used = memory_counter.use(weight) used = memory_counter.use(weight)
if used: if used:
target_device = weight.device target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device) self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
if self.weight_inplace_update:
comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1])) comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
else:
comfy.utils.set_attr_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
def clear_cached_hook_weights(self): def clear_cached_hook_weights(self):
self.cached_hook_patches.clear() self.cached_hook_patches.clear()
@ -1036,7 +1035,7 @@ class ModelPatcher:
used = memory_counter.use(weight) used = memory_counter.use(weight)
if used: if used:
target_device = weight.device target_device = weight.device
self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device) self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
# TODO: properly handle lowvram situations for cached hook patches # TODO: properly handle lowvram situations for cached hook patches
out_weight = comfy.lora.calculate_weight(combined_patches[key], out_weight = comfy.lora.calculate_weight(combined_patches[key],
comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True), comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True),
@ -1045,27 +1044,25 @@ class ModelPatcher:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so # TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
used = memory_counter.use(weight)
if used:
target_device = weight.device
self.cached_hook_patches.setdefault(hooks, {}) self.cached_hook_patches.setdefault(hooks, {})
self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=self.weight_inplace_update), weight.device) self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
if self.weight_inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight) comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
del weight del weight
del out_weight del out_weight
@torch.no_grad()
def unpatch_hooks(self) -> None: def unpatch_hooks(self) -> None:
with self.use_ejected(): with self.use_ejected():
if len(self.hook_backup) == 0: if len(self.hook_backup) == 0:
self.current_hooks = None self.current_hooks = None
return return
keys = list(self.hook_backup.keys()) keys = list(self.hook_backup.keys())
if self.weight_inplace_update:
for k in keys: for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear() self.hook_backup.clear()
self.current_hooks = None self.current_hooks = None