Fix error when more cond masks passed in than batch size (#3353)

This commit is contained in:
Jedrzej Kosinski 2024-04-26 11:51:12 -05:00 committed by GitHub
parent 16eabdf70d
commit 7990ae18c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -34,7 +34,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)