Implement Differential Diffusion (#2876)

* Implement Differential Diffusion

* Cleanup.

* Fix.

* Masks should be applied at full strength.

* Fix colors.

* Register the node.

* Cleaner code.

* Fix issue with getting unipc sampler.

* Adjust thresholds.

* Switch to linear thresholds.

* Only calculate nearest_idx on valid thresholds.
This commit is contained in:
Shiimizu 2024-03-03 12:34:13 -08:00 committed by GitHub
parent 1abf8374ec
commit 727021bdea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 100 additions and 0 deletions

View File

@ -277,6 +277,8 @@ class KSamplerX0Inpaint(torch.nn.Module):
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask)
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)

View File

@ -0,0 +1,97 @@
# code adapted from https://github.com/exx8/differential-diffusion
import torch
import inspect
class DifferentialDiffusion():
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
@classmethod
def IS_CHANGED(s, *args, **kwargs):
DifferentialDiffusion.INIT = s.INIT = True
return ""
def __init__(self) -> None:
DifferentialDiffusion.INIT = False
self.sigmas: torch.Tensor = None
self.thresholds: torch.Tensor = None
self.mask_i = None
self.valid_sigmas = False
self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]
def apply(self, model):
model = model.clone()
model.model_options["denoise_mask_function"] = self.forward
return (model,)
def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
self.__init__()
self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler
if self.sigmas is None:
self.sigmas = sigma[:1].repeat(2)
self.sigmas[-1].zero_()
self.sigmas_min = self.sigmas.min()
self.sigmas_max = self.sigmas.max()
self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
self.thresholds_min_len = self.thresholds.shape[0] - 1
if self.valid_sigmas:
thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1)
mask = denoise_mask.unsqueeze(0)
mask = (mask >= thresholds).to(denoise_mask.dtype)
self.mask_i = iter(mask)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
if self.sigmas is None or DifferentialDiffusion.INIT:
self.init_sigmas(sigma, denoise_mask)
if self.valid_sigmas:
try:
return next(self.mask_i)
except StopIteration:
self.valid_sigmas = False
if self.thresholds_min_len > 1:
nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
if not self.thresholds_min_len > nearest_idx:
nearest_idx = -2
threshold = self.thresholds[nearest_idx]
else:
threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
return (denoise_mask >= threshold).to(denoise_mask.dtype)
def get_sigmas_and_sampler(frame, target):
found = frame.f_locals[target]
if isinstance(found, torch.Tensor) and found[-1] < 0.1:
return found, frame.f_code.co_name
return False
def find_outer_instance(target: str, target_type=None, callback=None):
frame = inspect.currentframe()
i = 0
while frame and i < 100:
if target in frame.f_locals:
if callback is not None:
res = callback(frame, target)
if res:
return res
else:
found = frame.f_locals[target]
if isinstance(found, target_type):
return found
frame = frame.f_back
i += 1
return None
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}

View File

@ -1961,6 +1961,7 @@ def init_custom_nodes():
"nodes_photomaker.py",
"nodes_cond.py",
"nodes_stable_cascade.py",
"nodes_differential_diffusion.py",
]
for node_file in extras_files: