Fix issue with lora in some cases when combined with model merging.
This commit is contained in:
parent
58b2364f58
commit
09386a3697
20
comfy/sd.py
20
comfy/sd.py
|
@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||||
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
def set_attr(obj, attr, value):
|
||||||
|
attrs = attr.split(".")
|
||||||
|
for name in attrs[:-1]:
|
||||||
|
obj = getattr(obj, name)
|
||||||
|
prev = getattr(obj, attrs[-1])
|
||||||
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||||
|
del prev
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0):
|
def __init__(self, model, load_device, offload_device, size=0):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
@ -340,10 +348,11 @@ class ModelPatcher:
|
||||||
weight = model_sd[key]
|
weight = model_sd[key]
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(self.offload_device, copy=True)
|
self.backup[key] = weight.to(self.offload_device)
|
||||||
|
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
|
set_attr(self.model, key, out_weight)
|
||||||
del temp_weight
|
del temp_weight
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@ -439,13 +448,6 @@ class ModelPatcher:
|
||||||
|
|
||||||
def unpatch_model(self):
|
def unpatch_model(self):
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
def set_attr(obj, attr, value):
|
|
||||||
attrs = attr.split(".")
|
|
||||||
for name in attrs[:-1]:
|
|
||||||
obj = getattr(obj, name)
|
|
||||||
prev = getattr(obj, attrs[-1])
|
|
||||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
|
||||||
del prev
|
|
||||||
|
|
||||||
for k in keys:
|
for k in keys:
|
||||||
set_attr(self.model, k, self.backup[k])
|
set_attr(self.model, k, self.backup[k])
|
||||||
|
|
|
@ -6,7 +6,6 @@ import threading
|
||||||
import heapq
|
import heapq
|
||||||
import traceback
|
import traceback
|
||||||
import gc
|
import gc
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
|
|
Loading…
Reference in New Issue