use slice instead of torch.select()

This commit is contained in:
missionfloyd 2023-04-11 20:26:24 -06:00 committed by GitHub
parent e12fb88b1b
commit e1d289c1ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -1076,7 +1076,7 @@ class ImageToMask:
def image_to_mask(self, image, channel):
channels = ["red", "green", "blue"]
mask = torch.select(image[0], 2, channels.index(channel))
mask = image[0, :, :, channels.index(channel)]
return (mask,)
class MaskToImage: