Refactor calc_cond_uncond_batch into calc_cond_batch.

calc_cond_batch can take an arbitrary number of cond inputs.

Added a calc_cond_uncond_batch wrapper with a warning so custom nodes
won't break.
comfyanonymous 2024-04-01 17:23:07 -04:00
parent 1306464538
commit e6482fbbfc
3 changed files with 34 additions and 37 deletions
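As a usage sketch of the new call shape (placeholder names: model, x, sigma, model_options, positive, and negative stand in for whatever the caller already has; none of them come from this diff):

import comfy.samplers

# New API: pass a list of cond lists, get one output tensor per entry.
conds = [positive, negative]
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
cond_pred, uncond_pred = out[0], out[1]

# A None entry is allowed: nothing runs for that slot and its output
# stays an all-zeros tensor, which is how sampling_function below
# handles the disabled-uncond (cfg == 1) case.

# The deprecated wrapper keeps the old two-value call working,
# at the cost of a logged warning:
cond_pred, uncond_pred = comfy.samplers.calc_cond_uncond_batch(model, positive, negative, x, sigma, model_options)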

--- a/comfy/samplers.py
+++ b/comfy/samplers.py

@@ -127,30 +127,23 @@ def cond_cat(c_list):
     return out
 
-def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-    COND = 0
-    UNCOND = 1
-
+def calc_cond_batch(model, conds, x_in, timestep, model_options):
+    out_conds = []
+    out_counts = []
     to_run = []
-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
 
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
+    for i in range(len(conds)):
+        out_conds.append(torch.zeros_like(x_in))
+        out_counts.append(torch.ones_like(x_in) * 1e-37)
 
-            to_run += [(p, UNCOND)]
+        cond = conds[i]
+        if cond is not None:
+            for x in cond:
+                p = get_area_and_mult(x, x_in, timestep)
+                if p is None:
+                    continue
+
+                to_run += [(p, i)]
 
     while len(to_run) > 0:
         first = to_run[0]
@@ -222,22 +215,20 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
             output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-        del input_x
 
         for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-        del mult
+            cond_index = cond_or_uncond[o]
+            out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+            out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
 
-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
-    return out_cond, out_uncond
+    for i in range(len(out_conds)):
+        out_conds[i] /= out_counts[i]
+
+    return out_conds
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
+    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
+    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
 
 #The main sampling function shared by all the samplers
 #Returns denoised
@@ -247,7 +238,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
     else:
         uncond_ = uncond
 
-    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
+    conds = [cond, uncond_]
+
+    out = calc_cond_batch(model, conds, x, timestep, model_options)
+    cond_pred = out[0]
+    uncond_pred = out[1]
+
     if "sampler_cfg_function" in model_options:
         args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                 "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}

--- a/comfy_extras/nodes_perpneg.py
+++ b/comfy_extras/nodes_perpneg.py

@@ -31,7 +31,7 @@ class PerpNeg:
             model_options = args["model_options"]
             nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
 
-            (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
+            (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
 
             pos = noise_pred_pos - noise_pred_nocond
             neg = noise_pred_neg - noise_pred_nocond

--- a/comfy_extras/nodes_sag.py
+++ b/comfy_extras/nodes_sag.py

@@ -150,7 +150,7 @@ class SelfAttentionGuidance:
             degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
             degraded_noised = degraded + x - uncond_pred
             # call into the UNet
-            (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
+            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
             return cfg_result + (degraded - sag) * sag_scale
 
         m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
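For custom nodes that called calc_cond_uncond_batch with a single cond and None for the uncond, the two call sites above show that the migration is mechanical: wrap the cond in a list and unpack a one-element result. A minimal sketch with placeholder names (some_cond, x, sigma stand in for the caller's own values):

# Before: deprecated wrapper, second return value unused, warning logged.
(pred, _) = comfy.samplers.calc_cond_uncond_batch(model, some_cond, None, x, sigma, model_options)

# After: only the conds actually passed are evaluated and returned.
(pred,) = comfy.samplers.calc_cond_batch(model, [some_cond], x, sigma, model_options)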