Add some new weight patching functionality.

Add a way to reshape lora weights. Allow weight patches to apply to any weight, not just .weight and .bias. Add a way for a lora to set a weight to a specific value.
commit 41444b5236 (parent 772e620e32)
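The two new optional lora entries introduced below, "{}.reshape_weight" and "{}.set_weight", sit alongside the existing per-key entries such as "{}.dora_scale". A minimal sketch of a lora state dict carrying them, with made-up key names and shapes purely for illustration:

import torch

lora_sd = {
    # standard lora pair (already supported)
    "foo.lora_up.weight": torch.randn(8, 4),
    "foo.lora_down.weight": torch.randn(4, 16),
    # new: shape the patched weight should be padded out to before applying the lora
    "foo.reshape_weight": torch.tensor([8, 16]),
    # new: overwrite a weight with an explicit value instead of adding a delta
    "bar.set_weight": torch.randn(32, 32),
}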
@@ -49,6 +49,15 @@ def load_lora(lora, to_load):
             dora_scale = lora[dora_scale_name]
             loaded_keys.add(dora_scale_name)
 
+        reshape_name = "{}.reshape_weight".format(x)
+        reshape = None
+        if reshape_name in lora.keys():
+            try:
+                reshape = lora[reshape_name].tolist()
+                loaded_keys.add(reshape_name)
+            except:
+                pass
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         diffusers2_lora = "{}.lora_B.weight".format(x)
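The .tolist() call above turns the stored shape tensor into a plain Python list, and the bare try/except skips malformed entries instead of failing the whole load. A quick illustration of that conversion (the shape value is made up):

import torch

reshape_entry = torch.tensor([640, 320, 3, 3])  # as it might be stored in the lora file
reshape = reshape_entry.tolist()                # [640, 320, 3, 3], a plain Python list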
@@ -82,7 +91,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
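For reference, the lora patch payload becomes a 6-tuple here; the new reshape value rides along at index 5, which is where calculate_weight picks it up further down (the field names are descriptive, not taken from the source):

# ("lora", (lora_up, lora_down, alpha, mid, dora_scale, reshape))
#             v[0]      v[1]     v[2]  v[3]    v[4]       v[5]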
@@ -193,6 +202,12 @@ def load_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
             loaded_keys.add(diff_bias_name)
 
+        set_weight_name = "{}.set_weight".format(x)
+        set_weight = lora.get(set_weight_name, None)
+        if set_weight is not None:
+            patch_dict[to_load[x]] = ("set", (set_weight,))
+            loaded_keys.add(set_weight_name)
+
     for x in lora.keys():
         if x not in loaded_keys:
             logging.warning("lora key not loaded: {}".format(x))
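The resulting entry is a plain ("set", (tensor,)) tuple keyed by the model weight name, so it is easy to see the data shape in isolation; a hedged sketch (the weight key and value below are hypothetical):

import torch

patch_dict = {}
new_value = torch.zeros(320, 320)  # value the target weight should be forced to
patch_dict["diffusion_model.foo.weight"] = ("set", (new_value,))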
@@ -282,11 +297,14 @@ def model_lora_keys_unet(model, key_map={}):
     sdk = sd.keys()
 
     for k in sdk:
-        if k.startswith("diffusion_model.") and k.endswith(".weight"):
-            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
-            key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
-            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+        if k.startswith("diffusion_model."):
+            if k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                key_map["lora_unet_{}".format(key_lora)] = k
+                key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            else:
+                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
 
     diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:
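The practical effect of the new else branch: diffusion_model keys that do not end in .weight (biases and any other registered parameter) now get a pass-through mapping instead of being skipped. A small illustration of both paths with made-up key names:

k = "diffusion_model.input_blocks.0.0.weight"   # example .weight key
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
# key_lora == "input_blocks_0_0", so "lora_unet_input_blocks_0_0",
# "lora_prior_unet_input_blocks_0_0" and "diffusion_model.input_blocks.0.0"
# all map to k, as before

k2 = "diffusion_model.input_blocks.0.0.bias"    # non-.weight key, new behaviour
# key_map[k2] = k2, so a lora can now target the bias (or any other parameter) directly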
@@ -440,10 +458,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
             else:
                 weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
+        elif patch_type == "set":
+            weight.copy_(v[0])
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
             dora_scale = v[4]
+            reshape = v[5]
+
+            if reshape is not None:
+                weight = pad_tensor_to_shape(weight, reshape)
+
             if v[2] is not None:
                 alpha = v[2] / mat2.shape[0]
             else:
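The new "set" branch simply copies the stored tensor over the existing weight in place via copy_(). pad_tensor_to_shape, on the other hand, is used but not shown in this diff; a minimal sketch of what such a helper could look like, assuming it grows the weight to the requested shape by zero-padding while keeping the original values in place (this is an assumption about its behaviour, not the actual implementation):

import torch

def pad_tensor_to_shape(tensor, new_shape):
    # Assumed behaviour: allocate a zero tensor of the target shape and copy
    # the original values into the leading corner, so existing weights survive
    # and the newly exposed rows/columns start out at zero.
    padded = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
    padded[tuple(slice(0, s) for s in tensor.shape)] = tensor
    return padded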
@@ -373,14 +373,18 @@ class ModelPatcher:
         lowvram_counter = 0
         loading = []
         for n, m in self.model.named_modules():
-            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
-                loading.append((comfy.model_management.module_size(m), n, m))
+            params = []
+            for name, param in m.named_parameters(recurse=False):
+                params.append(name)
+            if hasattr(m, "comfy_cast_weights") or len(params) > 0:
+                loading.append((comfy.model_management.module_size(m), n, m, params))
 
         load_completely = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
             m = x[2]
+            params = x[3]
             module_mem = x[0]
 
             lowvram_weight = False
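The switch from hasattr(m, "weight") to enumerating named_parameters(recurse=False) is what lets the loader see every parameter a module actually registers, not just weight and bias. A quick check with a stock module (torch.nn.Linear used purely as an example):

import torch

m = torch.nn.Linear(4, 4)
params = [name for name, _ in m.named_parameters(recurse=False)]
# params == ['weight', 'bias']; a module with additional registered parameters
# (e.g. a learned scale) would report those here too, and they now get patched as well.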
@@ -416,22 +420,21 @@ class ModelPatcher:
                 if m.comfy_cast_weights:
                     wipe_lowvram_weight(m)
 
-                if hasattr(m, "weight"):
-                    mem_counter += module_mem
-                    load_completely.append((module_mem, n, m))
+                mem_counter += module_mem
+                load_completely.append((module_mem, n, m, params))
 
         load_completely.sort(reverse=True)
         for x in load_completely:
             n = x[1]
             m = x[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
+            params = x[3]
 
             if hasattr(m, "comfy_patched_weights"):
                 if m.comfy_patched_weights == True:
                     continue
 
-            self.patch_weight_to_device(weight_key, device_to=device_to)
-            self.patch_weight_to_device(bias_key, device_to=device_to)
+            for param in params:
+                self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
 
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
             m.comfy_patched_weights = True
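With params threaded through the load list, the final loop builds each parameter's full dotted name instead of the fixed weight/bias pair. A condensed sketch of what that loop produces (module and parameter names are illustrative):

n = "diffusion_model.input_blocks.0.0"      # example module name
params = ["weight", "bias", "scale"]        # whatever named_parameters() reported
keys = ["{}.{}".format(n, p) for p in params]
# keys == ["diffusion_model.input_blocks.0.0.weight",
#          "diffusion_model.input_blocks.0.0.bias",
#          "diffusion_model.input_blocks.0.0.scale"]
# each key is passed to patch_weight_to_device, replacing the old weight/bias-only calls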