Implement Linear hypernetworks.
Add a HypernetworkLoader node to use hypernetworks.
This commit is contained in:
parent
6908f9c949
commit
5282f56434
|
@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
query = self.to_q(x)
|
query = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
key = self.to_k(context)
|
key = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
value = self.to_v(value)
|
||||||
|
else:
|
||||||
value = self.to_v(context)
|
value = self.to_v(context)
|
||||||
|
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
@ -256,12 +260,16 @@ class CrossAttentionDoggettx(nn.Module):
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k_in = self.to_k(context)
|
k_in = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
v_in = self.to_v(value)
|
||||||
|
del value
|
||||||
|
else:
|
||||||
v_in = self.to_v(context)
|
v_in = self.to_v(context)
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
|
@ -350,12 +358,16 @@ class CrossAttention(nn.Module):
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
v = self.to_v(value)
|
||||||
|
del value
|
||||||
|
else:
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
@ -402,10 +414,14 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
v = self.to_v(value)
|
||||||
|
del value
|
||||||
|
else:
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
b, _, _ = q.shape
|
b, _, _ = q.shape
|
||||||
|
@ -447,10 +463,14 @@ class CrossAttentionPytorch(nn.Module):
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
v = self.to_v(value)
|
||||||
|
del value
|
||||||
|
else:
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
b, _, _ = q.shape
|
b, _, _ = q.shape
|
||||||
|
@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
|
||||||
transformer_patches = {}
|
transformer_patches = {}
|
||||||
|
|
||||||
n = self.norm1(x)
|
n = self.norm1(x)
|
||||||
|
if self.disable_self_attn:
|
||||||
|
context_attn1 = context
|
||||||
|
else:
|
||||||
|
context_attn1 = None
|
||||||
|
value_attn1 = None
|
||||||
|
|
||||||
|
if "attn1_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_patch"]
|
||||||
|
if context_attn1 is None:
|
||||||
|
context_attn1 = n
|
||||||
|
value_attn1 = context_attn1
|
||||||
|
for p in patch:
|
||||||
|
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
||||||
|
|
||||||
if "tomesd" in transformer_options:
|
if "tomesd" in transformer_options:
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
|
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
|
||||||
else:
|
else:
|
||||||
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
|
@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
|
||||||
x = p(current_index, x)
|
x = p(current_index, x)
|
||||||
|
|
||||||
n = self.norm2(x)
|
n = self.norm2(x)
|
||||||
n = self.attn2(n, context=context)
|
|
||||||
|
context_attn2 = context
|
||||||
|
value_attn2 = None
|
||||||
|
if "attn2_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn2_patch"]
|
||||||
|
value_attn2 = context_attn2
|
||||||
|
for p in patch:
|
||||||
|
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
||||||
|
|
||||||
|
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
|
|
|
@ -133,6 +133,7 @@ def unload_model():
|
||||||
#never unload models from GPU on high vram
|
#never unload models from GPU on high vram
|
||||||
if vram_state != VRAMState.HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
current_loaded_model.model.cpu()
|
current_loaded_model.model.cpu()
|
||||||
|
current_loaded_model.model_patches_to("cpu")
|
||||||
current_loaded_model.unpatch_model()
|
current_loaded_model.unpatch_model()
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
|
|
||||||
|
@ -156,6 +157,8 @@ def load_model_gpu(model):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model.unpatch_model()
|
model.unpatch_model()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
model.model_patches_to(get_torch_device())
|
||||||
current_loaded_model = model
|
current_loaded_model = model
|
||||||
if vram_state == VRAMState.CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -197,6 +197,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
transformer_options = model_options['transformer_options'].copy()
|
transformer_options = model_options['transformer_options'].copy()
|
||||||
|
|
||||||
if patches is not None:
|
if patches is not None:
|
||||||
|
if "patches" in transformer_options:
|
||||||
|
cur_patches = transformer_options["patches"].copy()
|
||||||
|
for p in patches:
|
||||||
|
if p in cur_patches:
|
||||||
|
cur_patches[p] = cur_patches[p] + patches[p]
|
||||||
|
else:
|
||||||
|
cur_patches[p] = patches[p]
|
||||||
|
else:
|
||||||
transformer_options["patches"] = patches
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
23
comfy/sd.py
23
comfy/sd.py
|
@ -254,6 +254,29 @@ class ModelPatcher:
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
|
|
||||||
|
|
||||||
|
def set_model_patch(self, patch, name):
    """Register ``patch`` under ``name`` in the transformer options.

    The patch list for ``name`` is rebuilt as a new list (old entries
    followed by ``patch``) instead of being appended to in place, so any
    other holder of the previous list object does not see the new patch.
    """
    transformer_opts = self.model_options["transformer_options"]
    registry = transformer_opts.setdefault("patches", {})
    existing = registry.get(name, [])
    registry[name] = existing + [patch]
|
||||||
|
|
||||||
|
def set_model_attn1_patch(self, patch):
    """Register ``patch`` through ``set_model_patch`` under the name "attn1_patch"."""
    # NOTE(review): presumably these are the patches the transformer block
    # invokes before its first attention call -- confirm against the caller.
    self.set_model_patch(patch, "attn1_patch")
|
||||||
|
|
||||||
|
def set_model_attn2_patch(self, patch):
    """Register ``patch`` through ``set_model_patch`` under the name "attn2_patch"."""
    # NOTE(review): presumably these are the patches the transformer block
    # invokes before its second attention call -- confirm against the caller.
    self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
|
def model_patches_to(self, device):
    """Move every registered patch that exposes a ``to`` method to ``device``.

    Each such patch is replaced, in place inside its list, by whatever its
    ``to(device)`` call returns. Patches without a ``to`` attribute are left
    untouched. Does nothing when no patches have been registered.
    """
    transformer_opts = self.model_options["transformer_options"]
    if "patches" not in transformer_opts:
        return

    for entries in transformer_opts["patches"].values():
        for position, entry in enumerate(entries):
            if hasattr(entry, "to"):
                entries[position] = entry.to(device)
|
||||||
|
|
||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
return self.model.diffusion_model.dtype
|
return self.model.diffusion_model.dtype
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def load_torch_file(ckpt):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||||
|
else:
|
||||||
|
if safe_load:
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def load_hypernetwork_patch(path, strength):
    """Load a hypernetwork checkpoint and wrap it as an attention patch.

    Only the plain "linear" hypernetwork format is handled: no activation
    function, layer norm, dropout, output activation or last-layer dropout.

    Args:
        path: Path of the checkpoint, read with
            ``comfy.utils.load_torch_file(path, safe_load=True)``.
        strength: Multiplier applied to the hypernetwork output before it
            is added to the key / value inputs.

    Returns:
        A callable patch object ``patch(current_index, q, k, v)`` returning
        ``(q, k, v)``, or ``None`` when the checkpoint uses an unsupported
        format (a message is printed in that case).
    """
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False:
        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
        return None

    out = {}

    for d in sd:
        # Entries whose key parses as an integer hold the per-dimension
        # weights; every other key (the option entries read above, etc.)
        # is skipped.
        try:
            dim = int(d)
        except (TypeError, ValueError):
            continue

        output = []
        for index in [0, 1]:
            # Look the entry up with the original key ``d``: ``dim`` is only
            # a valid key when the checkpoint stores its sizes as ints, and a
            # numeric-string key would otherwise raise KeyError here.
            attn_weights = sd[d][index]

            # One Linear layer per "<name>.weight" entry, in sorted name order.
            linears = sorted(
                key[:-len(".weight")]
                for key in attn_weights.keys()
                if key.endswith(".weight")
            )

            layers = []
            for lin_name in linears:
                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                # Weight is (out_features, in_features).
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        """Callable that adds the hypernetwork output to the k and v inputs."""

        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, current_index, q, k, v):
            # The network pair is selected by the size of k's last dimension;
            # inputs with no matching entry pass through unchanged.
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            """Move every stored network to ``device`` and return ``self``."""
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
|
||||||
|
|
||||||
|
class HypernetworkLoader:
    """Node that returns a clone of a model with a hypernetwork patch attached."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the required inputs: model, hypernetwork file and strength."""
        required_inputs = {
            "model": ("MODEL",),
            "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
            "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "_for_testing"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        """Clone ``model`` and register the hypernetwork on both attention patch lists.

        When the checkpoint cannot be turned into a patch (the loader
        returns ``None``), the unpatched clone is returned.
        """
        checkpoint_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
        patched_model = model.clone()
        hypernet = load_hypernetwork_patch(checkpoint_path, strength)
        if hypernet is None:
            return (patched_model,)

        # The same patch object is registered for both attention calls.
        patched_model.set_model_attn1_patch(hypernet)
        patched_model.set_model_attn2_patch(hypernet)
        return (patched_model,)
|
||||||
|
|
||||||
|
# Maps the node name to the class implementing it.
# NOTE(review): presumably read by load_custom_node when this file is loaded -- confirm.
NODE_CLASS_MAPPINGS = {
    "HypernetworkLoader": HypernetworkLoader
}
|
|
@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m
|
||||||
|
|
||||||
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
||||||
|
|
||||||
|
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||||
|
|
||||||
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
||||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||||
|
|
1
nodes.py
1
nodes.py
|
@ -1268,6 +1268,7 @@ def load_custom_nodes():
|
||||||
|
|
||||||
def init_custom_nodes():
|
def init_custom_nodes():
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
|
|
Loading…
Reference in New Issue