Fix Pan and Scan on batched images Gemma3 #36864
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the "Ready for review" button (at the bottom of the PR page).
Hm, interesting, since we had a test for PaS in image-processors which was green. Or was the error in processing code?
```python
# Add the thumbnails to the image patches
stacked_images = [stacked_images] + pas_images
# Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(stacked_images)
for shape, stacked_image_patches in grouped_image_patches.items():
    stacked_image_patches = self.resize(
        image=stacked_image_patches,
        size=size,
        interpolation=interpolation,
    )
    processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
# Transpose to have the thumbnails with their corresponding patches
stacked_images = torch.stack(processed_image_patches, dim=0).transpose(0, 1).contiguous()
```
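For readers less familiar with these helpers, here is a minimal sketch of the group/reorder round-trip they perform. This is a toy reimplementation for illustration only, not the actual transformers code:

```python
import torch

def group_images_by_shape(images: list) -> tuple[dict, list]:
    # Bucket images by shape so each bucket can be stacked and processed
    # as a single batched tensor (one kernel launch per distinct shape).
    grouped, index = {}, []
    for image in images:
        shape = tuple(image.shape)
        grouped.setdefault(shape, []).append(image)
        index.append((shape, len(grouped[shape]) - 1))
    stacked = {shape: torch.stack(imgs) for shape, imgs in grouped.items()}
    return stacked, index

def reorder_images(processed: dict, index: list) -> list:
    # Undo the grouping: restore the original ordering after the per-shape
    # batched processing is done.
    return [processed[shape][pos] for shape, pos in index]

# Round-trip check with two distinct image sizes:
images = [torch.rand(3, 224, 224), torch.rand(3, 448, 448), torch.rand(3, 224, 224)]
grouped, index = group_images_by_shape(images)
restored = reorder_images(grouped, index)
assert all(torch.equal(a, b) for a, b in zip(images, restored))
```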
not related to this PR particularly. Seeing a second grouping by size while we are already using a grouped image batch leads me to believe the batching logic in fast processors is over-complicated. Would be nice if we could simplify things, especially for community contributors when they add a new model with new processing like PaS
Yes, this doesn't look great, agreed. This is quite a specific case where we start with images of potentially different sizes, then split them into patches and concatenate the patches with the original image (which is bigger than the patches), before resizing everything to the same size 😅.
This new code is a bit overkill, as we group the patches and images by size at every step. It's not really necessary to have that in a first implementation by external contributors, so hopefully they won't ever have to do this to get a working fast image processor.
Not sure if there's a simpler way to fully use batch processing, I guess it's case by case
yeah, no rush to work on that. I don't think we have been forcing users to add only fast processors for now. We can come back to this question later. Maybe we'll find a better way to batch, or we can add some guides on how to add special processing in fast image processors
➕ on spending time to write simpler code! We can find some magic here and there.
In this specific case, I don't necessarily know. But sometimes padding then unpadding can lead to better performance (see the sketch after this list):
- Ordering can be costly
- Depending on the size of the padding, we are not really adding too much compute
- The distribution of image sizes is important to take into account!
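For illustration, a minimal sketch of that pad-then-unpad idea, assuming a shape-preserving op (normalization, rescaling, ...); the helper name is made up, not part of the transformers API:

```python
import torch
import torch.nn.functional as F

def process_with_padding(images: list, op) -> list:
    # Pad every image to the largest (H, W) in the batch so a single batched
    # op can run, then crop each result back to its original size.
    # Only valid when `op` preserves the spatial dimensions.
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    padded = torch.stack(
        [F.pad(img, (0, max_w - img.shape[-1], 0, max_h - img.shape[-2])) for img in images]
    )
    out = op(padded)  # one kernel launch instead of one per size group
    return [out[i, ..., : img.shape[-2], : img.shape[-1]] for i, img in enumerate(images)]

# Example: normalize a ragged batch in one shot.
images = [torch.rand(3, 224, 224), torch.rand(3, 448, 336)]
normalized = process_with_padding(images, lambda x: (x - 0.5) / 0.5)
```

Whether this beats grouping by shape depends on the points above, mostly the image-size distribution.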
I don't think PaS was tested with batched inputs, was it? Edit: I see it was, but the problem is really with num_crops: for example, if you have image inputs like [[image1, image2], [image3]], the num_crops returned will be [[2, 2], [2]], which will crash when trying to return pt tensors.
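A small repro of that failure mode (illustrative values, matching the example above):

```python
import torch

# num_crops as previously returned for inputs [[image1, image2], [image3]]:
nested_num_crops = [[2, 2], [2]]
try:
    torch.tensor(nested_num_crops)
except ValueError as e:
    print("crashes:", e)  # ragged nesting cannot form a rectangular tensor

# With one entry per flattened image, conversion is fine:
flat_num_crops = [2, 2, 2]
print(torch.tensor(flat_num_crops))
```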
Ah right, we didn't test different number of images per batch in PaS. As long as the gemma3 tests are green, lgtm. Thanks for fixing and the new test!
```python
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
```
we can use `torch.testing.assert_close` here also, AFAIR that works better for tensor match tests
I'll try to see if I can make it work, but I've had some issues with `torch.testing.assert_close` when comparing pixel values, because the relative differences can be very high and it's difficult to choose an rtol value that will work for all processors (the base tests comparing slow and fast also don't use `torch.testing.assert_close` to compare pixel values).
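To illustrate the issue: with near-zero pixel values, a tiny absolute difference becomes a huge relative one, so a pure-atol check passes while assert_close with a tight rtol fails (made-up values):

```python
import torch

slow = torch.tensor([1e-6, 0.5, 1.0])
fast = torch.tensor([2e-6, 0.5, 1.0])  # off by 1e-6 absolute, but 100% relative

# Passes: every element is within the absolute tolerance.
assert torch.allclose(slow, fast, atol=1e-1)

# Fails: the first element's relative error is 1.0, far above any sane rtol.
try:
    torch.testing.assert_close(slow, fast, rtol=1e-5, atol=0.0)
except AssertionError:
    print("assert_close fails on near-zero values")
```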
Thanks, as half of the changes make the code simpler! (flat list of images vs nested!)
> ➕ on spending time to write simpler code! We can find some magic here and there.
> In this specific case, I don't necessarily know. But sometimes padding then unpadding can lead to better performance:
> - Ordering can be costly
> - Depending on the size of the padding, we are not really adding too much compute
> - The distribution of image sizes is important to take into account!
True! In general, I think now that we have several techniques for fast processing, I need to make a new benchmark to compare them for each model (batched vs unbatched, padded vs unpadded, different techniques for splitting images into patches, etc.).
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: gemma3
This comment contains run-slow, running the specified jobs: models: ['models/gemma3']
@zucchini-nlp CI won't run the integration tests because of `require_read_token`
It worked, we had already requested access to gemma-3. All green, cool
Ah great, merging then |
* process flattened images in fast image proc
* process flattened images in slow proc and add tests
* remove print
* add unbalanced batch test pas image proc
* fix integration tests
What does this PR do?
Currently, inputs such as this one: `[[image1, image2], [image3]]` (a batch with an unbalanced number of images per sample)
will crash with both the slow and the fast image processor.
Non-batched inputs with pan and scan will also fail with fast image processors.
This PR fixes both issues and simplifies the image processing by processing flattened images instead of nested ones.
This PR also introduces some changes to take better advantage of batch processing in the fast image processor, with a batched `pan_and_scan` method.
It also adds some tests.
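A usage sketch of the case this fixes; the checkpoint name and the `do_pan_and_scan` kwarg are assumptions based on the Gemma 3 processors, so treat them as illustrative:

```python
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("google/gemma-3-4b-it", use_fast=True)

def random_image(h, w):
    return Image.fromarray(np.random.randint(0, 256, (h, w, 3), dtype=np.uint8))

image1, image2, image3 = random_image(300, 400), random_image(500, 200), random_image(400, 400)

# Unbalanced batch (two images for the first sample, one for the second) with
# pan and scan enabled: this previously crashed when building the num_crops tensor.
inputs = processor([[image1, image2], [image3]], do_pan_and_scan=True, return_tensors="pt")
print(inputs.pixel_values.shape, inputs.num_crops)
```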
Cc @zucchini-nlp @RyanMullins