diff --git a/comfy/lora.py b/comfy/lora.py
index b745ca4d..18602f24 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -49,6 +49,15 @@ def load_lora(lora, to_load):
             dora_scale = lora[dora_scale_name]
             loaded_keys.add(dora_scale_name)
 
+        reshape_name = "{}.reshape_weight".format(x)
+        reshape = None
+        if reshape_name in lora.keys():
+            try:
+                reshape = lora[reshape_name].tolist()
+                loaded_keys.add(reshape_name)
+            except:
+                pass
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         diffusers2_lora = "{}.lora_B.weight".format(x)
@@ -82,7 +91,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
 
@@ -193,6 +202,12 @@ def load_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
             loaded_keys.add(diff_bias_name)
 
+        set_weight_name = "{}.set_weight".format(x)
+        set_weight = lora.get(set_weight_name, None)
+        if set_weight is not None:
+            patch_dict[to_load[x]] = ("set", (set_weight,))
+            loaded_keys.add(set_weight_name)
+
     for x in lora.keys():
         if x not in loaded_keys:
             logging.warning("lora key not loaded: {}".format(x))
@@ -282,11 +297,14 @@ def model_lora_keys_unet(model, key_map={}):
     sdk = sd.keys()
 
     for k in sdk:
-        if k.startswith("diffusion_model.") and k.endswith(".weight"):
-            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
-            key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
-            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+        if k.startswith("diffusion_model."):
+            if k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                key_map["lora_unet_{}".format(key_lora)] = k
+                key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            else:
+                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
 
     diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:
@@ -440,10 +458,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                     logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
                     weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
+        elif patch_type == "set":
+            weight.copy_(v[0])
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
             dora_scale = v[4]
+            reshape = v[5]
+
+            if reshape is not None:
+                weight = pad_tensor_to_shape(weight, reshape)
+
             if v[2] is not None:
                 alpha = v[2] / mat2.shape[0]
             else:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 3bba217a..22de7eea 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -373,14 +373,18 @@ class ModelPatcher:
         lowvram_counter = 0
         loading = []
         for n, m in self.model.named_modules():
-            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
-                loading.append((comfy.model_management.module_size(m), n, m))
+            params = []
+            for name, param in m.named_parameters(recurse=False):
+                params.append(name)
+            if hasattr(m, "comfy_cast_weights") or len(params) > 0:
+                loading.append((comfy.model_management.module_size(m), n, m, params))
 
         load_completely = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
             m = x[2]
+            params = x[3]
             module_mem = x[0]
 
             lowvram_weight = False
@@ -416,22 +420,21 @@ class ModelPatcher:
                     if m.comfy_cast_weights:
                         wipe_lowvram_weight(m)
 
-                if hasattr(m, "weight"):
-                    mem_counter += module_mem
-                    load_completely.append((module_mem, n, m))
+                mem_counter += module_mem
+                load_completely.append((module_mem, n, m, params))
 
         load_completely.sort(reverse=True)
         for x in load_completely:
             n = x[1]
             m = x[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
+            params = x[3]
 
             if hasattr(m, "comfy_patched_weights"):
                 if m.comfy_patched_weights == True:
                     continue
 
-            self.patch_weight_to_device(weight_key, device_to=device_to)
-            self.patch_weight_to_device(bias_key, device_to=device_to)
+            for param in params:
+                self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
             m.comfy_patched_weights = True