Add a TomePatchModel node to the _for_testing section.
Tome increases sampling speed at the expense of quality.
parent 7e682784d7
commit 18a6c1db33
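Note (not part of the commit): a rough back-of-the-envelope reading of the speed/quality tradeoff. The hook added below merges a fraction `ratio` of the tokens before self-attention, so attn1 sees roughly N - int(N * ratio) tokens instead of N. Assuming the attention matmul cost scales with the square of the token count, a small sketch of the relative cost:

# Illustrative estimate only; real gains depend on resolution, batch size,
# and everything else in the block besides attn1.
def attn1_cost_ratio(num_tokens: int, ratio: float) -> float:
    merged = num_tokens - int(num_tokens * ratio)   # tokens left after merging
    return (merged ** 2) / (num_tokens ** 2)        # quadratic attention cost, relative to no merging

print(attn1_cost_ratio(64 * 64, 0.3))  # ~0.49 -> roughly half the attn1 cost at the default ratio
print(attn1_cost_ratio(64 * 64, 0.5))  # ~0.25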
@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 import model_management

+from . import tomesd

 if model_management.xformers_enabled():
     import xformers
@@ -508,8 +509,18 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None, transformer_options={}):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
-        x = self.attn2(self.norm2(x), context=context) + x
+        n = self.norm1(x)
+        if "tomesd" in transformer_options:
+            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
+            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+        else:
+            n = self.attn1(n, context=context if self.disable_self_attn else None)
+
+        x += n
+        n = self.norm2(x)
+        n = self.attn2(n, context=context)
+
+        x += n
         x = self.ff(self.norm3(x)) + x
         return x
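For the "tomesd" branch above to fire, transformer_options has to carry both a "tomesd" entry and the latent's original shape. A minimal sketch of the expected payload, with made-up example values; "original_shape" is populated elsewhere in the model's forward pass and is not part of this diff:

# Illustrative transformer_options payload for the hook above (example values only).
transformer_options = {
    "tomesd": {"ratio": 0.3},          # set via ModelPatcher.set_model_tomesd / the TomePatchModel node
    "original_shape": (2, 4, 64, 64),  # latent (B, C, H, W), supplied upstream of the attention block
}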
@@ -0,0 +1,117 @@
(new file, tomesd.py; every line below is added)

import torch
from typing import Tuple, Callable
import math


def do_nothing(x: torch.Tensor, mode:str=None):
    return x


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():

        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)

        if no_rand:
            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)

        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
        rand_idx = idx_buffer.argsort(dim=1)

        num_dst = int((1 / (sx*sy)) * N)
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.sqrt(original_tokens // x.shape[1]))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = original_w // downsample
        h = original_h // downsample
        r = int(x.shape[1] * ratio)
        no_rand = True
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing
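A quick round-trip check of the two closures returned above, using assumed example shapes (a 64x64 latent, 320-channel tokens, ratio 0.5): merge shrinks the token dimension and unmerge restores it.

import torch
# get_functions as defined in the new file above.

x = torch.randn(2, 64 * 64, 320)                    # [B, N, C] transformer tokens
m, u = get_functions(x, ratio=0.5, original_shape=(2, 4, 64, 64))

merged = m(x)          # the r most redundant src tokens are averaged into their dst tokens
restored = u(merged)   # merged values are copied back to their original positions

print(merged.shape)    # torch.Size([2, 2048, 320])
print(restored.shape)  # torch.Size([2, 4096, 320])

Also worth noting: with max_downsample = 1, merging only happens in the highest-resolution transformer blocks (downsample == 1); deeper blocks fall through to the identity `nothing` closures.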
@@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             out['c_concat'] = [torch.cat(c_concat)]
         return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0

@@ -195,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con

     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
     return uncond + (cond - uncond) * cond_scale


@@ -212,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
         return out


@@ -221,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
         if denoise_mask is not None:
             out *= denoise_mask

@@ -333,7 +333,7 @@ class KSampler:
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]

-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":

@@ -353,6 +353,7 @@ class KSampler:
         self.sigma_max=float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)
         self.denoise = denoise
+        self.model_options = model_options

     def _calculate_sigmas(self, steps):
         sigmas = None

@@ -421,7 +422,7 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext

-        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None
         if hasattr(self.model, 'concat_keys'):
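The samplers.py changes are pure plumbing: model_options rides from KSampler into extra_args, which the sampler forwards as keyword arguments to the wrapper's model call, and from there into sampling_function and calc_cond_uncond_batch. A compressed sketch with stand-in functions that only mirror the signatures in the hunks above (not the real ComfyUI classes):

# Stand-in sketch of the "model_options" threading; names and bodies are simplified.
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale,
                      cond_concat=None, model_options={}):
    # in the real code, calc_cond_uncond_batch(..., model_options) is called here
    return model_options

class CFGNoisePredictorSketch:
    def apply_model(self, x, timestep, cond, uncond, cond_scale,
                    cond_concat=None, model_options={}):
        return sampling_function(None, x, timestep, uncond, cond, cond_scale,
                                 cond_concat, model_options=model_options)

# k-diffusion style: the sampler expands extra_args as keyword arguments on every call
extra_args = {"cond": "positive", "uncond": "negative", "cond_scale": 8.0,
              "model_options": {"transformer_options": {"tomesd": {"ratio": 0.3}}}}
opts = CFGNoisePredictorSketch().apply_model(x=None, timestep=None, **extra_args)
print(opts["transformer_options"]["tomesd"]["ratio"])  # 0.3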
@@ -1,5 +1,6 @@
 import torch
 import contextlib
+import copy

 import sd1_clip
 import sd2_clip

@@ -274,12 +275,20 @@ class ModelPatcher:
         self.model = model
         self.patches = []
         self.backup = {}
+        self.model_options = {"transformer_options":{}}

     def clone(self):
         n = ModelPatcher(self.model)
         n.patches = self.patches[:]
+        n.model_options = copy.deepcopy(self.model_options)
         return n

+    def set_model_tomesd(self, ratio):
+        self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
+
     def model_dtype(self):
         return self.model.diffusion_model.dtype

     def add_patches(self, patches, strength=1.0):
         p = {}
         model_sd = self.model.state_dict()
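Usage sketch for the ModelPatcher changes: clone() deep-copies model_options, so setting the ToMe ratio on a clone leaves the original untouched. `patcher` below is a hypothetical name for an existing ModelPatcher instance.

# Illustrative only; `patcher` stands for an already constructed ModelPatcher.
tome_patcher = patcher.clone()
tome_patcher.set_model_tomesd(0.3)

print(tome_patcher.model_options["transformer_options"])  # {'tomesd': {'ratio': 0.3}}
print(patcher.model_options["transformer_options"])       # {} -> the original is unaffected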
nodes.py (19 changed lines)

@@ -254,6 +254,22 @@ class LoraLoader:
         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
         return (model_lora, clip_lora)

+class TomePatchModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, ratio):
+        m = model.clone()
+        m.set_model_tomesd(ratio)
+        return (m, )
+
 class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
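Outside the graph UI, the node boils down to a clone-and-patch call. A minimal sketch, assuming `model` holds the MODEL output of a loader node; in normal use ComfyUI invokes the node through the graph instead:

# Equivalent of wiring MODEL -> TomePatchModel(ratio=0.3) -> KSampler in the graph.
(patched_model,) = TomePatchModel().patch(model, ratio=0.3)
# patched_model.model_options now carries {"transformer_options": {"tomesd": {"ratio": 0.3}}}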
@@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         model_management.load_controlnet_gpu(control_net_models)

     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
     else:
         #other samplers
         pass

@@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = {
     "CLIPVisionLoader": CLIPVisionLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
     "VAEEncodeTiled": VAEEncodeTiled,
+    "TomePatchModel": TomePatchModel,
 }

 def load_custom_node(module_path):