Support hypernetwork with mish activation function and layer norm.

This commit is contained in:
comfyanonymous 2023-10-17 12:08:03 -04:00
parent 92f0318630
commit f8caa24bcc
1 changed files with 11 additions and 2 deletions

View File

@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
"tanh": torch.nn.Tanh, "tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid, "sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign, "softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
linears = list(map(lambda a: a[:-len(".weight")], linears)) linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = [] layers = []
for i in range(len(linears)): i = 0
while i < len(linears):
lin_name = linears[i] lin_name = linears[i]
last_layer = (i == (len(linears) - 1)) last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2)) penultimate_layer = (i == (len(linears) - 2))
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
if (not last_layer) or (activate_output): if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]()) layers.append(valid_activation[activation_func]())
if is_layer_norm: if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) i += 1
ln_name = linears[i]
ln_weight = attn_weights['{}.weight'.format(ln_name)]
ln_bias = attn_weights['{}.bias'.format(ln_name)]
ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
layers.append(ln)
if use_dropout: if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout): if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3)) layers.append(torch.nn.Dropout(p=0.3))
i += 1
output.append(torch.nn.Sequential(*layers)) output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output) out[dim] = torch.nn.ModuleList(output)