Remove unnecessary torch.no_grad calls for hook patches
This commit is contained in:
parent
e844695292
commit
0850ae5c04
|
@ -990,23 +990,21 @@ class ModelPatcher:
|
|||
# if have cached weights for hooks, use it
|
||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||
if cached_weights is not None:
|
||||
with torch.no_grad():
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
else:
|
||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||
original_weights = None
|
||||
if len(relevant_patches) > 0:
|
||||
original_weights = self.get_key_patches()
|
||||
with torch.no_grad():
|
||||
for key in relevant_patches:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
for key in relevant_patches:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
memory_counter=memory_counter)
|
||||
self.current_hooks = hooks
|
||||
|
||||
|
@ -1054,7 +1052,6 @@ class ModelPatcher:
|
|||
del weight
|
||||
del out_weight
|
||||
|
||||
@torch.no_grad()
|
||||
def unpatch_hooks(self) -> None:
|
||||
with self.use_ejected():
|
||||
if len(self.hook_backup) == 0:
|
||||
|
|
Loading…
Reference in New Issue