Add a LoraLoader node to apply loras to models and clip.
The models are modified in place right before being used and unpatched afterwards. I think this is better than monkey patching since it might make it easier to use faster non-PyTorch UNet inference in the future.
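To illustrate the idea (this sketch is not part of the commit), the intended lifecycle looks roughly like the snippet below; it mirrors how common_ksampler and CLIP.encode use the new ModelPatcher further down. lora_patches and run_inference are hypothetical placeholders.

    # Rough sketch of the patch/use/unpatch lifecycle introduced by this commit.
    # `lora_patches` and `run_inference` are hypothetical placeholders.
    patcher = comfy.sd.ModelPatcher(model)           # wrap the raw torch model
    patcher.add_patches(lora_patches, strength=1.0)  # queue weight deltas, nothing applied yet
    try:
        real_model = patcher.patch_model()   # weights modified in place, originals backed up
        run_inference(real_model)            # sampling / encoding runs on the patched model
        patcher.unpatch_model()              # originals restored from the backup
    except Exception as e:
        patcher.unpatch_model()              # always restore, even on failure
        raise e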
commit ef90e9c376 (parent 96664f5d5e)

comfy/sd.py (186 changed lines)
@@ -6,10 +6,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
 
-
-def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
-    print(f"Loading model from {ckpt}")
-
+def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
@@ -21,6 +18,12 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
             sd = pl_sd["state_dict"]
         else:
             sd = pl_sd
+    return sd
+
+def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
+    print(f"Loading model from {ckpt}")
+
+    sd = load_torch_file(ckpt)
     model = instantiate_from_config(config.model)
 
     m, u = model.load_state_dict(sd, strict=False)
@@ -50,10 +53,160 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
     model.eval()
     return model
 
+LORA_CLIP_MAP = {
+    "mlp.fc1": "mlp_fc1",
+    "mlp.fc2": "mlp_fc2",
+    "self_attn.k_proj": "self_attn_k_proj",
+    "self_attn.q_proj": "self_attn_q_proj",
+    "self_attn.v_proj": "self_attn_v_proj",
+    "self_attn.out_proj": "self_attn_out_proj",
+}
+
+LORA_UNET_MAP = {
+    "proj_in": "proj_in",
+    "proj_out": "proj_out",
+    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
+    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
+    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
+    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
+    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
+    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
+    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
+    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
+    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
+    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
+}
+
+
+def load_lora(path, to_load):
+    lora = load_torch_file(path)
+    patch_dict = {}
+    loaded_keys = set()
+    for x in to_load:
+        A_name = "{}.lora_up.weight".format(x)
+        B_name = "{}.lora_down.weight".format(x)
+        alpha_name = "{}.alpha".format(x)
+        if A_name in lora.keys():
+            alpha = None
+            if alpha_name in lora.keys():
+                alpha = lora[alpha_name].item()
+                loaded_keys.add(alpha_name)
+            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+    for x in lora.keys():
+        if x not in loaded_keys:
+            print("lora key not loaded", x)
+    return patch_dict
+
+def model_lora_keys(model, key_map={}):
+    sdk = model.state_dict().keys()
+
+    counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    for c in LORA_UNET_MAP:
+        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        if k in sdk:
+            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
+            key_map[lora_key] = k
+    counter = 3
+    for b in range(12):
+        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    counter = 0
+    for b in range(12):
+        for c in LORA_CLIP_MAP:
+            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = k
+    return key_map
+
+class ModelPatcher:
+    def __init__(self, model):
+        self.model = model
+        self.patches = []
+        self.backup = {}
+
+    def clone(self):
+        n = ModelPatcher(self.model)
+        n.patches = self.patches[:]
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        p = {}
+        model_sd = self.model.state_dict()
+        for k in patches:
+            if k in model_sd:
+                p[k] = patches[k]
+        self.patches += [(strength, p)]
+        return p.keys()
+
+    def patch_model(self):
+        model_sd = self.model.state_dict()
+        for p in self.patches:
+            for k in p[1]:
+                v = p[1][k]
+                if k not in model_sd:
+                    print("could not patch. key doesn't exist in model:", k)
+                    continue
+
+                weight = model_sd[k]
+                if k not in self.backup:
+                    self.backup[k] = weight.clone()
+
+                alpha = p[0]
+                mat1 = v[0]
+                mat2 = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return self.model
+
+    def unpatch_model(self):
+        model_sd = self.model.state_dict()
+        for k in self.backup:
+            model_sd[k][:] = self.backup[k]
+        self.backup = {}
+
+
+def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+    key_map = model_lora_keys(model.model)
+    key_map = model_lora_keys(clip.cond_stage_model, key_map)
+    loaded = load_lora(lora_path, key_map)
+    new_modelpatcher = model.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+    new_clip = clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("NOT LOADED", x)
+
+    return (new_modelpatcher, new_clip)
+
+
 class CLIP:
-    def __init__(self, config, embedding_directory=None):
+    def __init__(self, config={}, embedding_directory=None, no_init=False):
+        if no_init:
+            return
         self.target_clip = config["target"]
         if "params" in config:
             params = config["params"]
@@ -72,13 +225,30 @@ class CLIP:
         self.cond_stage_model = clip(**(params))
         self.tokenizer = tokenizer(**(tokenizer_params))
+        self.patcher = ModelPatcher(self.cond_stage_model)
+
+    def clone(self):
+        n = CLIP(no_init=True)
+        n.target_clip = self.target_clip
+        n.patcher = self.patcher.clone()
+        n.cond_stage_model = self.cond_stage_model
+        n.tokenizer = self.tokenizer
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        return self.patcher.add_patches(patches, strength)
 
     def encode(self, text):
         tokens = self.tokenizer.tokenize_with_weights(text)
-        cond = self.cond_stage_model.encode_token_weights(tokens)
+        try:
+            self.patcher.patch_model()
+            cond = self.cond_stage_model.encode_token_weights(tokens)
+            self.patcher.unpatch_model()
+        except Exception as e:
+            self.patcher.unpatch_model()
+            raise e
         return cond
 
 
 class VAE:
     def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
         if config is None:
@@ -135,4 +305,4 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
         load_state_dict_to = [w]
 
     model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
-    return (model, clip, vae)
+    return (ModelPatcher(model), clip, vae)
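As an illustration (not part of the diff), the weight update in ModelPatcher.patch_model above boils down to weight += strength * (alpha / rank) * up @ down. A standalone sketch with assumed shapes:

    # Standalone illustration of the LoRA merge done in ModelPatcher.patch_model.
    # Shapes and values here are assumptions for the example, not taken from the commit.
    import torch

    rank = 4
    weight = torch.randn(320, 768)        # original weight (out_features x in_features)
    lora_up = torch.randn(320, rank)      # "lora_up.weight"   -> v[0] / mat1
    lora_down = torch.randn(rank, 768)    # "lora_down.weight" -> v[1] / mat2
    lora_alpha = 4.0                      # optional ".alpha" scalar -> v[2]
    strength = 1.0                        # strength_model / strength_clip from the node

    alpha = strength * (lora_alpha / lora_down.shape[0])   # alpha *= v[2] / mat2.shape[0]
    delta = torch.mm(lora_up.flatten(start_dim=1), lora_down.flatten(start_dim=1))
    weight += (alpha * delta).reshape(weight.shape).type(weight.dtype)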
nodes.py (80 changed lines)
@@ -130,6 +130,27 @@ class CheckpointLoader:
         embedding_directory = os.path.join(self.models_dir, "embeddings")
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
 
+class LoraLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    lora_dir = os.path.join(models_dir, "loras")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP", ),
+                              "lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ),
+                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP")
+    FUNCTION = "load_lora"
+
+    CATEGORY = "loaders"
+
+    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
+        lora_path = os.path.join(self.lora_dir, lora_name)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        return (model_lora, clip_lora)
+
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     vae_dir = os.path.join(models_dir, "vae")
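One note on the design (my reading, not from the commit message): because load_lora_for_models clones the ModelPatcher and CLIP wrappers before adding patches, chaining LoraLoader nodes stacks LoRAs without mutating the upstream objects. A hypothetical sketch with made-up file names:

    # Hypothetical chaining example; "styleA.safetensors" / "styleB.safetensors" are made up.
    loader = LoraLoader()
    model_a, clip_a = loader.load_lora(model, clip, "styleA.safetensors", 1.0, 1.0)
    model_b, clip_b = loader.load_lora(model_a, clip_a, "styleB.safetensors", 0.5, 0.5)
    # model_b carries both patch sets; the original `model` wrapper is left untouched.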
@@ -268,35 +289,43 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     else:
         noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
 
-    model = model.to(device)
-    noise = noise.to(device)
-    latent_image = latent_image.to(device)
-
-    positive_copy = []
-    negative_copy = []
-
-    for p in positive:
-        t = p[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        positive_copy += [[t] + p[1:]]
-    for n in negative:
-        t = n[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        negative_copy += [[t] + n[1:]]
-
-    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
-    else:
-        #other samplers
-        pass
-
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
-    samples = samples.cpu()
-    model = model.cpu()
+    try:
+        real_model = model.patch_model()
+        real_model.to(device)
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+
+        positive_copy = []
+        negative_copy = []
+
+        for p in positive:
+            t = p[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            positive_copy += [[t] + p[1:]]
+        for n in negative:
+            t = n[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            negative_copy += [[t] + n[1:]]
+
+        if sampler_name in comfy.samplers.KSampler.SAMPLERS:
+            sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        else:
+            #other samplers
+            pass
+
+        samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
+        samples = samples.cpu()
+        real_model.cpu()
+        model.unpatch_model()
+    except Exception as e:
+        real_model.cpu()
+        model.unpatch_model()
+        raise e
+
     return (samples, )
 
 class KSampler:
@@ -452,6 +481,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentComposite": LatentComposite,
     "LatentRotate": LatentRotate,
     "LatentFlip": LatentFlip,
+    "LoraLoader": LoraLoader,
 }
 
 