Fix error when more cond masks passed in than batch size (#3353)
This commit is contained in:
parent
16eabdf70d
commit
7990ae18c1
|
@ -34,7 +34,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
|
|
Loading…
Reference in New Issue