Refactor unCLIP noise augment out of samplers.py
This commit is contained in:
parent
7b2f09b5fa
commit
c64ca8c0b2
|
@ -60,6 +60,37 @@ class SD21UNCLIP(BaseModel):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(unet_config, v_prediction)
|
||||||
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
|
if unclip_conditioning is not None:
|
||||||
|
adm_inputs = []
|
||||||
|
weights = []
|
||||||
|
noise_aug = []
|
||||||
|
for unclip_cond in unclip_conditioning:
|
||||||
|
adm_cond = unclip_cond["clip_vision_output"].image_embeds
|
||||||
|
weight = unclip_cond["strength"]
|
||||||
|
noise_augment = unclip_cond["noise_augmentation"]
|
||||||
|
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||||
|
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
||||||
|
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||||
|
weights.append(weight)
|
||||||
|
noise_aug.append(noise_augment)
|
||||||
|
adm_inputs.append(adm_out)
|
||||||
|
|
||||||
|
if len(noise_aug) > 1:
|
||||||
|
adm_out = torch.stack(adm_inputs).sum(0)
|
||||||
|
#TODO: add a way to control this
|
||||||
|
noise_augment = 0.05
|
||||||
|
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||||
|
c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
||||||
|
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
|
else:
|
||||||
|
adm_out = torch.zeros((1, self.adm_channels))
|
||||||
|
|
||||||
|
return adm_out
|
||||||
|
|
||||||
class SDInpaint(BaseModel):
|
class SDInpaint(BaseModel):
|
||||||
def __init__(self, unet_config, v_prediction=False):
|
def __init__(self, unet_config, v_prediction=False):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(unet_config, v_prediction)
|
||||||
|
|
|
@ -460,42 +460,18 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
uncond[temp[1]] = [o[0], n]
|
uncond[temp[1]] = [o[0], n]
|
||||||
|
|
||||||
|
|
||||||
def encode_adm(conds, batch_size, device, noise_augmentor=None):
|
def encode_adm(model, conds, batch_size, device):
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
adm_out = None
|
adm_out = None
|
||||||
if noise_augmentor is not None:
|
if 'adm' in x[1]:
|
||||||
if 'adm' in x[1]:
|
adm_out = x[1]["adm"]
|
||||||
adm_inputs = []
|
|
||||||
weights = []
|
|
||||||
noise_aug = []
|
|
||||||
adm_in = x[1]["adm"]
|
|
||||||
for adm_c in adm_in:
|
|
||||||
adm_cond = adm_c[0].image_embeds
|
|
||||||
weight = adm_c[1]
|
|
||||||
noise_augment = adm_c[2]
|
|
||||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
|
||||||
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
|
||||||
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
|
||||||
weights.append(weight)
|
|
||||||
noise_aug.append(noise_augment)
|
|
||||||
adm_inputs.append(adm_out)
|
|
||||||
|
|
||||||
if len(noise_aug) > 1:
|
|
||||||
adm_out = torch.stack(adm_inputs).sum(0)
|
|
||||||
#TODO: add a way to control this
|
|
||||||
noise_augment = 0.05
|
|
||||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
|
||||||
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
|
||||||
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
|
||||||
else:
|
|
||||||
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
|
|
||||||
else:
|
else:
|
||||||
if 'adm' in x[1]:
|
params = x[1].copy()
|
||||||
adm_out = x[1]["adm"].to(device)
|
adm_out = model.encode_adm(device=device, **params)
|
||||||
if adm_out is not None:
|
if adm_out is not None:
|
||||||
x[1] = x[1].copy()
|
x[1] = x[1].copy()
|
||||||
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
|
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
|
||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
@ -603,11 +579,8 @@ class KSampler:
|
||||||
precision_scope = contextlib.nullcontext
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
if self.model.is_adm():
|
if self.model.is_adm():
|
||||||
noise_augmentor = None
|
positive = encode_adm(self.model, positive, noise.shape[0], self.device)
|
||||||
if hasattr(self.model, 'noise_augmentor'): #unclip
|
negative = encode_adm(self.model, negative, noise.shape[0], self.device)
|
||||||
noise_augmentor = self.model.noise_augmentor
|
|
||||||
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
|
|
||||||
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
|
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||||
|
|
||||||
|
|
8
nodes.py
8
nodes.py
|
@ -623,11 +623,11 @@ class unCLIPConditioning:
|
||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
o = t[1].copy()
|
o = t[1].copy()
|
||||||
x = (clip_vision_output, strength, noise_augmentation)
|
x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
|
||||||
if "adm" in o:
|
if "unclip_conditioning" in o:
|
||||||
o["adm"] = o["adm"][:] + [x]
|
o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
|
||||||
else:
|
else:
|
||||||
o["adm"] = [x]
|
o["unclip_conditioning"] = [x]
|
||||||
n = [t[0], o]
|
n = [t[0], o]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
Loading…
Reference in New Issue