diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index d16c49ae..f692945a 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, "softsign": torch.nn.Softsign, + "mish": torch.nn.Mish, } if activation_func not in valid_activation: @@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength): linears = list(map(lambda a: a[:-len(".weight")], linears)) layers = [] - for i in range(len(linears)): + i = 0 + while i < len(linears): lin_name = linears[i] last_layer = (i == (len(linears) - 1)) penultimate_layer = (i == (len(linears) - 2)) @@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength): if (not last_layer) or (activate_output): layers.append(valid_activation[activation_func]()) if is_layer_norm: - layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + i += 1 + ln_name = linears[i] + ln_weight = attn_weights['{}.weight'.format(ln_name)] + ln_bias = attn_weights['{}.bias'.format(ln_name)] + ln = torch.nn.LayerNorm(ln_weight.shape[0]) + ln.load_state_dict({"weight": ln_weight, "bias": ln_bias}) + layers.append(ln) if use_dropout: if (not last_layer) and (not penultimate_layer or last_layer_dropout): layers.append(torch.nn.Dropout(p=0.3)) + i += 1 output.append(torch.nn.Sequential(*layers)) out[dim] = torch.nn.ModuleList(output)