From ef90e9c376fd9b7ed40fe38ba6695be67d5fc9b9 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 3 Feb 2023 02:06:34 -0500
Subject: [PATCH] Add a LoraLoader node to apply LoRAs to models and CLIP.

The models are modified in place before being used and unpatched afterwards.
I think this is better than monkey-patching since it might make it easier to
use faster non-PyTorch UNet inference in the future.
---
 comfy/sd.py                 | 186 ++++++++++++++++++++++++++++++++++--
 models/loras/put_loras_here |   0
 nodes.py                    |  80 +++++++++++-----
 3 files changed, 233 insertions(+), 33 deletions(-)
 create mode 100644 models/loras/put_loras_here

diff --git a/comfy/sd.py b/comfy/sd.py
index 13776f1b..bb7023f7 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -6,10 +6,7 @@
 from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
-
-def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
-    print(f"Loading model from {ckpt}")
-
+def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
@@ -21,6 +18,12 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
             sd = pl_sd["state_dict"]
         else:
             sd = pl_sd
+    return sd
+
+def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
+    print(f"Loading model from {ckpt}")
+
+    sd = load_torch_file(ckpt)
     model = instantiate_from_config(config.model)
 
     m, u = model.load_state_dict(sd, strict=False)
@@ -50,10 +53,160 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
     model.eval()
     return model
 
+LORA_CLIP_MAP = {
+    "mlp.fc1": "mlp_fc1",
+    "mlp.fc2": "mlp_fc2",
+    "self_attn.k_proj": "self_attn_k_proj",
+    "self_attn.q_proj": "self_attn_q_proj",
+    "self_attn.v_proj": "self_attn_v_proj",
+    "self_attn.out_proj": "self_attn_out_proj",
+}
+
+LORA_UNET_MAP = {
+    "proj_in": "proj_in",
+    "proj_out": "proj_out",
+    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
+    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
+    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
+    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
+    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
+    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
+    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
+    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
+    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
+    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
+}
+
+
+def load_lora(path, to_load):
+    lora = load_torch_file(path)
+    patch_dict = {}
+    loaded_keys = set()
+    for x in to_load:
+        A_name = "{}.lora_up.weight".format(x)
+        B_name = "{}.lora_down.weight".format(x)
+        alpha_name = "{}.alpha".format(x)
+        if A_name in lora.keys():
+            alpha = None
+            if alpha_name in lora.keys():
+                alpha = lora[alpha_name].item()
+                loaded_keys.add(alpha_name)
+            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+    for x in lora.keys():
+        if x not in loaded_keys:
+            print("lora key not loaded", x)
+    return patch_dict
+
+def model_lora_keys(model, key_map={}):
+    sdk = model.state_dict().keys()
+
+    counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
"{}.{}.weight".format(tk, c) + if k in sdk: + lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c]) + key_map[lora_key] = k + up_counter += 1 + if up_counter >= 4: + counter += 1 + for c in LORA_UNET_MAP: + k = "model.diffusion_model.middle_block.1.{}.weight".format(c) + if k in sdk: + lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c]) + key_map[lora_key] = k + counter = 3 + for b in range(12): + tk = "model.diffusion_model.output_blocks.{}.1".format(b) + up_counter = 0 + for c in LORA_UNET_MAP: + k = "{}.{}.weight".format(tk, c) + if k in sdk: + lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c]) + key_map[lora_key] = k + up_counter += 1 + if up_counter >= 4: + counter += 1 + counter = 0 + for b in range(12): + for c in LORA_CLIP_MAP: + k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + return key_map + +class ModelPatcher: + def __init__(self, model): + self.model = model + self.patches = [] + self.backup = {} + + def clone(self): + n = ModelPatcher(self.model) + n.patches = self.patches[:] + return n + + def add_patches(self, patches, strength=1.0): + p = {} + model_sd = self.model.state_dict() + for k in patches: + if k in model_sd: + p[k] = patches[k] + self.patches += [(strength, p)] + return p.keys() + + def patch_model(self): + model_sd = self.model.state_dict() + for p in self.patches: + for k in p[1]: + v = p[1][k] + if k not in model_sd: + print("could not patch. key doesn't exist in model:", k) + continue + + weight = model_sd[k] + if k not in self.backup: + self.backup[k] = weight.clone() + + alpha = p[0] + mat1 = v[0] + mat2 = v[1] + if v[2] is not None: + alpha *= v[2] / mat2.shape[0] + weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + return self.model + def unpatch_model(self): + model_sd = self.model.state_dict() + for k in self.backup: + model_sd[k][:] = self.backup[k] + self.backup = {} + +def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip): + key_map = model_lora_keys(model.model) + key_map = model_lora_keys(clip.cond_stage_model, key_map) + loaded = load_lora(lora_path, key_map) + new_modelpatcher = model.clone() + k = new_modelpatcher.add_patches(loaded, strength_model) + new_clip = clip.clone() + k1 = new_clip.add_patches(loaded, strength_clip) + k = set(k) + k1 = set(k1) + for x in loaded: + if (x not in k) and (x not in k1): + print("NOT LOADED", x) + + return (new_modelpatcher, new_clip) class CLIP: - def __init__(self, config, embedding_directory=None): + def __init__(self, config={}, embedding_directory=None, no_init=False): + if no_init: + return self.target_clip = config["target"] if "params" in config: params = config["params"] @@ -72,13 +225,30 @@ class CLIP: self.cond_stage_model = clip(**(params)) self.tokenizer = tokenizer(**(tokenizer_params)) + self.patcher = ModelPatcher(self.cond_stage_model) + + def clone(self): + n = CLIP(no_init=True) + n.target_clip = self.target_clip + n.patcher = self.patcher.clone() + n.cond_stage_model = self.cond_stage_model + n.tokenizer = self.tokenizer + return n + + def add_patches(self, patches, strength=1.0): + return self.patcher.add_patches(patches, strength) def encode(self, text): tokens = 
-        cond = self.cond_stage_model.encode_token_weights(tokens)
+        try:
+            self.patcher.patch_model()
+            cond = self.cond_stage_model.encode_token_weights(tokens)
+            self.patcher.unpatch_model()
+        except Exception as e:
+            self.patcher.unpatch_model()
+            raise e
         return cond
 
-
 class VAE:
     def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
         if config is None:
@@ -135,4 +305,4 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
             load_state_dict_to = [w]
 
     model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
-    return (model, clip, vae)
+    return (ModelPatcher(model), clip, vae)
diff --git a/models/loras/put_loras_here b/models/loras/put_loras_here
new file mode 100644
index 00000000..e69de29b
diff --git a/nodes.py b/nodes.py
index 2ad81a0c..37952e31 100644
--- a/nodes.py
+++ b/nodes.py
@@ -130,6 +130,27 @@ class CheckpointLoader:
         embedding_directory = os.path.join(self.models_dir, "embeddings")
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
 
+class LoraLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    lora_dir = os.path.join(models_dir, "loras")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP", ),
+                              "lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ),
+                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP")
+    FUNCTION = "load_lora"
+
+    CATEGORY = "loaders"
+
+    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
+        lora_path = os.path.join(self.lora_dir, lora_name)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        return (model_lora, clip_lora)
+
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     vae_dir = os.path.join(models_dir, "vae")
@@ -268,35 +289,43 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     else:
         noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
 
-    model = model.to(device)
-    noise = noise.to(device)
-    latent_image = latent_image.to(device)
+    try:
+        real_model = model.patch_model()
+        real_model.to(device)
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
 
-    positive_copy = []
-    negative_copy = []
+        positive_copy = []
+        negative_copy = []
 
-    for p in positive:
-        t = p[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        positive_copy += [[t] + p[1:]]
-    for n in negative:
-        t = n[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        negative_copy += [[t] + n[1:]]
+        for p in positive:
+            t = p[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            positive_copy += [[t] + p[1:]]
+        for n in negative:
+            t = n[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            negative_copy += [[t] + n[1:]]
 
-    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
-    else:
-        #other samplers
-        pass
+        if sampler_name in comfy.samplers.KSampler.SAMPLERS:
+            sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        else:
+            #other samplers
+            pass
+
+        samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
+        samples = samples.cpu()
+        real_model.cpu()
+        model.unpatch_model()
+    except Exception as e:
+        real_model.cpu()
+        model.unpatch_model()
+        raise e
 
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
-    samples = samples.cpu()
-    model = model.cpu()
     return (samples, )
 
 class KSampler:
@@ -452,6 +481,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentComposite": LatentComposite,
     "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
+    "LoraLoader": LoraLoader,
 }
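
A note on the mechanics, not part of the diff above: the heart of the patch
is the weight merge in ModelPatcher.patch_model, where each LoRA pair
contributes strength * (alpha / rank) * (lora_up @ lora_down) to the
matching base weight, and unpatch_model restores the backed-up values. A
minimal standalone sketch of that patch/unpatch cycle, with hypothetical
shapes and values chosen purely for illustration:

    import torch

    out_features, in_features, rank = 320, 320, 4    # hypothetical layer sizes
    alpha, strength = 4.0, 1.0                       # LoRA alpha and user strength

    weight = torch.randn(out_features, in_features)  # base model weight
    up = torch.randn(out_features, rank)             # the "lora_up.weight" tensor
    down = torch.randn(rank, in_features)            # the "lora_down.weight" tensor

    backup = weight.clone()                          # what ModelPatcher.backup keeps

    # patch_model: merge the low-rank product into the weight in place.
    # down.shape[0] is the rank, matching "alpha *= v[2] / mat2.shape[0]" above.
    weight += strength * (alpha / down.shape[0]) * torch.mm(up, down)

    # ... run inference with the patched weights ...

    # unpatch_model: restore the original values in place.
    weight[:] = backup

Because the merge happens on the weight tensors themselves rather than by
wrapping forward calls, a patched model looks exactly like an ordinarily
loaded one, which is what keeps the door open for non-PyTorch UNet backends.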
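
At the API level, load_checkpoint now returns the model wrapped in a
ModelPatcher, LoraLoader clones it and registers patches via add_patches,
and the weights are only touched inside common_ksampler and CLIP.encode. A
rough usage sketch under those assumptions; the config and file paths here
are hypothetical:

    import comfy.sd

    # Hypothetical paths, for illustration only.
    model, clip, vae = comfy.sd.load_checkpoint(
        "v1-inference.yaml", "models/checkpoints/model.ckpt")

    # Returns clones; the inputs stay unpatched, so several strength
    # settings can share one set of base weights.
    model_lora, clip_lora = comfy.sd.load_lora_for_models(
        model, clip, "models/loras/example.safetensors", 1.0, 1.0)

Since add_patches only records (strength, tensors) pairs, chaining several
LoraLoader nodes composes naturally: each clone carries the accumulated
patch list, and patch_model applies the patches in order at sampling time.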