Fix to LoadImage Node for #3416 HDR images loading additional smaller… (#3454)

* Fix to LoadImage Node for #3416 HDR images loading additional smaller images. 

Added a blocking if statement in the ImageSequence.Iterator loop that checks whether each image after the first matches the first image's dimensions, and prevents any that do not match from being appended to output_images.

This does not fix or change current behavior for PIL 10.2.0, where the images are loaded at the same size, but it does for 10.3.0, where they are loaded at their correct smaller sizes.

* Added a list of excluded formats that should return one image

Added an explicit check of the image format, so that additional formats with problematic behavior can be added to the exclusion list.
This commit is contained in:
shawnington 2024-05-12 04:07:38 -07:00 committed by GitHub
parent f509c6fe21
commit 22edd3add5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 13 additions and 1 deletion

View File

@ -1461,12 +1461,24 @@ class LoadImage:
output_images = [] output_images = []
output_masks = [] output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I': if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if 'A' in i.getbands():
@ -1477,7 +1489,7 @@ class LoadImage:
output_images.append(image) output_images.append(image)
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1: if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0) output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0) output_mask = torch.cat(output_masks, dim=0)
else: else: