Support sampling non 2D latents.
This commit is contained in:
parent
742d5720d1
commit
a5e6a632f9
|
@ -8,7 +8,8 @@ import logging
|
|||
import comfy.sampler_helpers
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
dims = tuple(x_in.shape[2:])
|
||||
area = None
|
||||
strength = 1.0
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
|
@ -20,11 +21,16 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
area = list(conds['area'])
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
input_x = x_in
|
||||
if area is not None:
|
||||
for i in range(len(dims)):
|
||||
area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
|
||||
input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])
|
||||
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
|
@ -32,28 +38,30 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
assert(mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
for i in range(len(dims)):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in conds:
|
||||
if 'mask' not in conds and area is not None:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
for i in range(len(dims)):
|
||||
if area[len(dims) + i] != 0:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
|
||||
for t in range(rr):
|
||||
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
|
||||
m *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
|
@ -219,8 +227,19 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
@ -335,7 +354,7 @@ def get_mask_aabb(masks):
|
|||
|
||||
return bounding_boxes, is_empty
|
||||
|
||||
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||
for i in range(len(conditions)):
|
||||
|
@ -344,7 +363,14 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||
area = c['area']
|
||||
if area[0] == "percentage":
|
||||
modified = c.copy()
|
||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
||||
a = area[1:]
|
||||
a_len = len(a) // 2
|
||||
area = ()
|
||||
for d in range(len(dims)):
|
||||
area += (max(1, round(a[d] * dims[d])),)
|
||||
for d in range(len(dims)):
|
||||
area += (round(a[d + a_len] * dims[d]),)
|
||||
|
||||
modified['area'] = area
|
||||
c = modified
|
||||
conditions[i] = c
|
||||
|
@ -353,12 +379,12 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||
mask = c['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c.copy()
|
||||
if len(mask.shape) == 2:
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != h or mask.shape[2] != w:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
|
||||
if modified.get("set_area_to_bounds", False):
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
boxes, is_empty = get_mask_aabb(bounds)
|
||||
if is_empty[0]:
|
||||
|
@ -375,7 +401,11 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||
modified['mask'] = mask
|
||||
conditions[i] = modified
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
|
||||
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
|
||||
if 'area' not in c:
|
||||
return
|
||||
|
||||
|
@ -479,7 +509,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
|
|||
params = x.copy()
|
||||
params["device"] = device
|
||||
params["noise"] = noise
|
||||
params["width"] = params.get("width", noise.shape[3] * 8)
|
||||
default_width = None
|
||||
if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
|
||||
default_width = noise.shape[3] * 8
|
||||
params["width"] = params.get("width", default_width)
|
||||
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
for k in kwargs:
|
||||
|
@ -567,7 +600,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
|
||||
for k in conds:
|
||||
calculate_start_end_timesteps(model, conds[k])
|
||||
|
|
Loading…
Reference in New Issue