Fix last commits causing an issue with the text encoder lora.
This commit is contained in:
parent
bf3f271775
commit
51581dbfa9
|
@ -357,11 +357,12 @@ class ModelPatcher:
|
||||||
self.patches += [(strength_patch, p, strength_model)]
|
self.patches += [(strength_patch, p, strength_model)]
|
||||||
return p.keys()
|
return p.keys()
|
||||||
|
|
||||||
def model_state_dict(self):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
sd = self.model.state_dict()
|
sd = self.model.state_dict()
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
|
if filter_prefix is not None:
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if not k.startswith("diffusion_model."):
|
if not k.startswith(filter_prefix):
|
||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
@ -443,7 +444,7 @@ class ModelPatcher:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||||
return self.model
|
return self.model
|
||||||
def unpatch_model(self):
|
def unpatch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model_state_dict()
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
model_sd[k][:] = self.backup[k]
|
model_sd[k][:] = self.backup[k]
|
||||||
|
|
|
@ -14,7 +14,7 @@ class ModelMergeSimple:
|
||||||
|
|
||||||
def merge(self, model1, model2, ratio):
|
def merge(self, model1, model2, ratio):
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
sd = model2.model_state_dict()
|
sd = model2.model_state_dict("diffusion_model.")
|
||||||
for k in sd:
|
for k in sd:
|
||||||
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
|
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
@ -35,7 +35,7 @@ class ModelMergeBlocks:
|
||||||
|
|
||||||
def merge(self, model1, model2, **kwargs):
|
def merge(self, model1, model2, **kwargs):
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
sd = model2.model_state_dict()
|
sd = model2.model_state_dict("diffusion_model.")
|
||||||
default_ratio = next(iter(kwargs.values()))
|
default_ratio = next(iter(kwargs.values()))
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
|
|
Loading…
Reference in New Issue