diff --git a/comfy/sd.py b/comfy/sd.py
index 7a079daa..61b59383 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
@@ -340,10 +348,11 @@ class ModelPatcher:
             weight = model_sd[key]
 
             if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device, copy=True)
+                self.backup[key] = weight.to(self.offload_device)
 
             temp_weight = weight.to(torch.float32, copy=True)
-            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            set_attr(self.model, key, out_weight)
             del temp_weight
         return self.model
 
@@ -439,13 +448,6 @@ class ModelPatcher:
 
     def unpatch_model(self):
         keys = list(self.backup.keys())
-        def set_attr(obj, attr, value):
-            attrs = attr.split(".")
-            for name in attrs[:-1]:
-                obj = getattr(obj, name)
-            prev = getattr(obj, attrs[-1])
-            setattr(obj, attrs[-1], torch.nn.Parameter(value))
-            del prev
 
         for k in keys:
             set_attr(self.model, k, self.backup[k])
diff --git a/execution.py b/execution.py
index a40b1dd3..f19d0b23 100644
--- a/execution.py
+++ b/execution.py
@@ -6,7 +6,6 @@ import threading
 import heapq
 import traceback
 import gc
-import time
 import torch
 import nodes
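
A minimal, self-contained sketch of what the hoisted set_attr helper does, runnable outside the repo. The ToyModel class, the backup/patched names, and the doubling stand-in for calculate_weight are hypothetical, for illustration only; set_attr itself is copied verbatim from the diff above.

import torch

# set_attr, as hoisted to module level in the diff: walk a dotted
# attribute path, replace the leaf tensor with a fresh Parameter, and
# drop the old reference so its storage can be reclaimed.
def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value))
    del prev

# Hypothetical toy module, used only to exercise the helper.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Linear(4, 4)

model = ToyModel()
backup = model.block.weight.detach().clone()  # plays the role of self.backup[key]
patched = backup * 2.0                        # stand-in for calculate_weight(...)

set_attr(model, "block.weight", patched)      # patch: swap in a new Parameter
assert torch.equal(model.block.weight, backup * 2.0)

set_attr(model, "block.weight", backup)       # unpatch: restore the backup
assert torch.equal(model.block.weight, backup)

Because patch_model now installs a brand-new Parameter via set_attr instead of writing into the existing tensor with weight[:] = ..., the original weight object is left untouched, which appears to be why the backup no longer needs copy=True.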