Support loading diffusers SD3 model format with UNETLoader node.
This commit is contained in:
parent
b08a9dd04b
commit
0d6a57938e
|
@ -1,7 +1,9 @@
|
|||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
|
@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
|
|||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
if depth > 0:
|
||||
out_sd = {}
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
|
||||
return out_sd
|
||||
|
|
|
@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
if new_sd is None:
|
||||
return None
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
|
|
|
@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
|
|||
|
||||
return diffusers_unet_map
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
MMDIT_MAP_BASIC = {
|
||||
("context_embedder.bias", "context_embedder.bias"),
|
||||
("context_embedder.weight", "context_embedder.weight"),
|
||||
|
@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
|
|||
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("pos_embed", "pos_embed.pos_embed"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
}
|
||||
|
@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
|||
for k in MMDIT_MAP_BLOCK:
|
||||
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
||||
|
||||
for k in MMDIT_MAP_BASIC:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
map_basic = MMDIT_MAP_BASIC.copy()
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
|
||||
|
||||
for k in map_basic:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
|
|
|
@ -52,9 +52,32 @@ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
|||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeSD3(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embed."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["context_embedder."] = argument
|
||||
arg_dict["y_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
|
||||
for i in range(38):
|
||||
arg_dict["joint_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
"ModelMergeSDXL": ModelMergeSDXL,
|
||||
"ModelMergeSD3": ModelMergeSD3,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue