Remove unnecessary torch.no_grad calls for hook patches

Jedrzej Kosinski 2024-11-18 07:38:11 -06:00
parent e844695292
commit 0850ae5c04
1 changed file with 10 additions and 13 deletions


@@ -990,23 +990,21 @@ class ModelPatcher:
             # if have cached weights for hooks, use it
             cached_weights = self.cached_hook_patches.get(hooks, None)
             if cached_weights is not None:
-                with torch.no_grad():
-                    for key in cached_weights:
-                        if key not in model_sd_keys:
-                            print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
-                            continue
-                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                for key in cached_weights:
+                    if key not in model_sd_keys:
+                        print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
+                        continue
+                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
             else:
                 relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                 original_weights = None
                 if len(relevant_patches) > 0:
                     original_weights = self.get_key_patches()
-                with torch.no_grad():
-                    for key in relevant_patches:
-                        if key not in model_sd_keys:
-                            print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
-                            continue
-                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
+                for key in relevant_patches:
+                    if key not in model_sd_keys:
+                        print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
+                        continue
+                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                      memory_counter=memory_counter)
             self.current_hooks = hooks
@@ -1054,7 +1052,6 @@ class ModelPatcher:
         del weight
         del out_weight
 
-    @torch.no_grad()
     def unpatch_hooks(self) -> None:
         with self.use_ejected():
            if len(self.hook_backup) == 0:
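
Why these guards can be redundant (the commit message gives no rationale, so this is an inference): PyTorch only records autograd history for operations that involve tensors with requires_grad=True, and weights loaded for inference normally do not track gradients, so wrapping the patch loops in torch.no_grad() adds nothing in that case. A minimal standalone sketch of that behavior, using plain PyTorch and hypothetical names rather than anything from ModelPatcher:

import torch

# Hypothetical illustration, not ComfyUI code: an in-place copy into a tensor
# that does not require grad builds no autograd graph, so no guard is needed.
inference_weight = torch.zeros(4)               # requires_grad=False, like weights loaded from a checkpoint
inference_weight.copy_(torch.randn(4))          # fine outside torch.no_grad()

# The guard only matters for a leaf tensor that tracks gradients.
training_weight = torch.zeros(4, requires_grad=True)
try:
    training_weight.copy_(torch.randn(4))       # RuntimeError: in-place op on a leaf that requires grad
except RuntimeError as err:
    print(f"unguarded copy failed: {err}")

with torch.no_grad():                           # needed only in this case
    training_weight.copy_(torch.randn(4))       # succeeds without recording history

If the weights touched by these patch loops all have requires_grad=False, which is typical for inference-only model loading, removing the guards changes nothing functionally.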