Support hypernetwork with mish activation function and layer norm.
parent 92f0318630
commit f8caa24bcc
@@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
         "softsign": torch.nn.Softsign,
+        "mish": torch.nn.Mish,
     }
 
     if activation_func not in valid_activation:
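The new entry maps the "mish" string stored in a hypernetwork file's activation_func field to torch.nn.Mish (mish(x) = x * tanh(softplus(x)), available in PyTorch 1.9+). A minimal sketch of how the lookup is used, with an abbreviated copy of the table above and made-up tensor values for illustration:

import torch

valid_activation = {
    "tanh": torch.nn.Tanh,
    "sigmoid": torch.nn.Sigmoid,
    "softsign": torch.nn.Softsign,
    "mish": torch.nn.Mish,
}

activation_func = "mish"  # in the real code this is read from the hypernetwork file
act = valid_activation[activation_func]()  # instantiate the activation module

x = torch.linspace(-3, 3, 7)
print(act(x))  # elementwise mish(x) = x * tanh(softplus(x))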
@@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
             linears = list(map(lambda a: a[:-len(".weight")], linears))
             layers = []
 
-            for i in range(len(linears)):
+            i = 0
+            while i < len(linears):
                 lin_name = linears[i]
                 last_layer = (i == (len(linears) - 1))
                 penultimate_layer = (i == (len(linears) - 2))
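The for loop becomes a while loop because, with layer norm enabled, the LayerNorm's own parameters also end in ".weight", so its entries appear in linears interleaved with the Linear entries and the loop must advance the index by an extra step to consume them (see the next hunk). A simplified sketch of this skipping; the key names are hypothetical, and the real code only consumes the extra entry when an activation is appended:

# Hypothetical key order for a hypernetwork saved with layer norm enabled:
linears = ["linear1", "ln1", "linear2", "ln2", "linear3"]
is_layer_norm = True

i = 0
while i < len(linears):
    print("linear:", linears[i])
    if is_layer_norm and i + 1 < len(linears):
        i += 1                              # consume the paired LayerNorm entry
        print("  layer norm:", linears[i])
    i += 1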
@@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
                 if (not last_layer) or (activate_output):
                     layers.append(valid_activation[activation_func]())
                     if is_layer_norm:
-                        layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
+                        i += 1
+                        ln_name = linears[i]
+                        ln_weight = attn_weights['{}.weight'.format(ln_name)]
+                        ln_bias = attn_weights['{}.bias'.format(ln_name)]
+                        ln = torch.nn.LayerNorm(ln_weight.shape[0])
+                        ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
+                        layers.append(ln)
                     if use_dropout:
                         if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                             layers.append(torch.nn.Dropout(p=0.3))
+                i += 1
 
             output.append(torch.nn.Sequential(*layers))
         out[dim] = torch.nn.ModuleList(output)
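This hunk also fixes the layer-norm handling: the old code appended a freshly initialized torch.nn.LayerNorm and never loaded its trained parameters, while the new code reads the saved weight and bias from the checkpoint and restores them. A minimal sketch of that load step, using made-up tensors in place of the attn_weights entries and a hypothetical feature size of 320:

import torch

# Stand-ins for attn_weights['{}.weight'.format(ln_name)] and the matching bias:
ln_weight = torch.ones(320)
ln_bias = torch.zeros(320)

ln = torch.nn.LayerNorm(ln_weight.shape[0])               # normalized_shape from the saved weight
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})  # restore trained affine parameters

x = torch.randn(2, 77, 320)
print(ln(x).shape)  # torch.Size([2, 77, 320])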