Refactor calc_cond_uncond_batch into calc_cond_batch.
calc_cond_batch can take an arbitrary amount of cond inputs. Added a calc_cond_uncond_batch wrapper with a warning so custom nodes won't break.
This commit is contained in:
parent
1306464538
commit
e6482fbbfc
|
@ -127,30 +127,23 @@ def cond_cat(c_list):
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_conds = []
|
||||||
out_count = torch.ones_like(x_in) * 1e-37
|
out_counts = []
|
||||||
|
|
||||||
out_uncond = torch.zeros_like(x_in)
|
|
||||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
|
||||||
|
|
||||||
COND = 0
|
|
||||||
UNCOND = 1
|
|
||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
|
|
||||||
|
for i in range(len(conds)):
|
||||||
|
out_conds.append(torch.zeros_like(x_in))
|
||||||
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||||
|
|
||||||
|
cond = conds[i]
|
||||||
|
if cond is not None:
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, i)]
|
||||||
if uncond is not None:
|
|
||||||
for x in uncond:
|
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_run += [(p, UNCOND)]
|
|
||||||
|
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
first = to_run[0]
|
first = to_run[0]
|
||||||
|
@ -222,22 +215,20 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
else:
|
else:
|
||||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
del input_x
|
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
if cond_or_uncond[o] == COND:
|
cond_index = cond_or_uncond[o]
|
||||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||||
else:
|
|
||||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
||||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
||||||
del mult
|
|
||||||
|
|
||||||
out_cond /= out_count
|
for i in range(len(out_conds)):
|
||||||
del out_count
|
out_conds[i] /= out_counts[i]
|
||||||
out_uncond /= out_uncond_count
|
|
||||||
del out_uncond_count
|
return out_conds
|
||||||
return out_cond, out_uncond
|
|
||||||
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||||
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||||
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
|
@ -247,7 +238,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||||
else:
|
else:
|
||||||
uncond_ = uncond
|
uncond_ = uncond
|
||||||
|
|
||||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
|
||||||
|
conds = [cond, uncond_]
|
||||||
|
|
||||||
|
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||||
|
cond_pred = out[0]
|
||||||
|
uncond_pred = out[1]
|
||||||
|
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||||
|
|
|
@ -31,7 +31,7 @@ class PerpNeg:
|
||||||
model_options = args["model_options"]
|
model_options = args["model_options"]
|
||||||
nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
||||||
|
|
||||||
(noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
|
(noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
|
||||||
|
|
||||||
pos = noise_pred_pos - noise_pred_nocond
|
pos = noise_pred_pos - noise_pred_nocond
|
||||||
neg = noise_pred_neg - noise_pred_nocond
|
neg = noise_pred_neg - noise_pred_nocond
|
||||||
|
|
|
@ -150,7 +150,7 @@ class SelfAttentionGuidance:
|
||||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||||
degraded_noised = degraded + x - uncond_pred
|
degraded_noised = degraded + x - uncond_pred
|
||||||
# call into the UNet
|
# call into the UNet
|
||||||
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
(sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
|
||||||
return cfg_result + (degraded - sag) * sag_scale
|
return cfg_result + (degraded - sag) * sag_scale
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
Loading…
Reference in New Issue