Fix issue with lora in some cases when combined with model merging.
This commit is contained in:
parent
58b2364f58
commit
09386a3697
20
comfy/sd.py
20
comfy/sd.py
|
@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
|
|||
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
return key_map
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||
del prev
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0):
|
||||
self.size = size
|
||||
|
@ -340,10 +348,11 @@ class ModelPatcher:
|
|||
weight = model_sd[key]
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(self.offload_device, copy=True)
|
||||
self.backup[key] = weight.to(self.offload_device)
|
||||
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
return self.model
|
||||
|
||||
|
@ -439,13 +448,6 @@ class ModelPatcher:
|
|||
|
||||
def unpatch_model(self):
|
||||
keys = list(self.backup.keys())
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||
del prev
|
||||
|
||||
for k in keys:
|
||||
set_attr(self.model, k, self.backup[k])
|
||||
|
|
|
@ -6,7 +6,6 @@ import threading
|
|||
import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
|
Loading…
Reference in New Issue