Fix issue with lora in some cases when combined with model merging.

This commit is contained in:
comfyanonymous 2023-07-21 21:27:27 -04:00
parent 58b2364f58
commit 09386a3697
2 changed files with 11 additions and 10 deletions

View File

@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
return key_map return key_map
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0): def __init__(self, model, load_device, offload_device, size=0):
self.size = size self.size = size
@ -340,10 +348,11 @@ class ModelPatcher:
weight = model_sd[key] weight = model_sd[key]
if key not in self.backup: if key not in self.backup:
self.backup[key] = weight.to(self.offload_device, copy=True) self.backup[key] = weight.to(self.offload_device)
temp_weight = weight.to(torch.float32, copy=True) temp_weight = weight.to(torch.float32, copy=True)
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight)
del temp_weight del temp_weight
return self.model return self.model
@ -439,13 +448,6 @@ class ModelPatcher:
def unpatch_model(self): def unpatch_model(self):
keys = list(self.backup.keys()) keys = list(self.backup.keys())
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
for k in keys: for k in keys:
set_attr(self.model, k, self.backup[k]) set_attr(self.model, k, self.backup[k])

View File

@ -6,7 +6,6 @@ import threading
import heapq import heapq
import traceback import traceback
import gc import gc
import time
import torch import torch
import nodes import nodes