Fix Pan and Scan on batched images Gemma3 #36864
Changes from all commits: b780c4e, e309045, 4a6c28d, 0aba6c8, 4550cb3, 9338947, 2d79716, c2152cc, e3ce1d9
```diff
@@ -189,6 +189,13 @@ def test_pan_and_scan(self):
         expected_output_image_shape = (9, 3, 18, 18)
         self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
 
+        # Test batched unbalanced, 9 images because we have base image + 2 crops per each item
+        encoded_images = image_processing(
+            [[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
+        ).pixel_values
+        expected_output_image_shape = (9, 3, 18, 18)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
     def test_call_pil(self):
         for image_processing_class in self.image_processor_list:
             # Initialize image_processing
```
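The new assertion expects 9 output tensors from 3 input images because, with these settings, pan and scan yields the base image plus 2 crops per image, regardless of how the images are nested across batch items. A minimal sketch of that bookkeeping (hypothetical helper, not part of transformers), assuming a fixed 2 crops per image:

```python
# Hypothetical sketch (not the transformers implementation): count the tensors
# that pan and scan produces for a nested, possibly unbalanced batch.
def count_pas_outputs(batch, crops_per_image=2):
    """Each image contributes itself plus `crops_per_image` crops."""
    return sum(len(item) * (1 + crops_per_image) for item in batch)

# 3 images split unevenly across 2 batch items -> 3 * (1 + 2) = 9 outputs
batch = [["img0", "img1"], ["img2"]]
assert count_pas_outputs(batch) == 9
```

This is why both the balanced batch earlier in the test and the unbalanced batch added here expect the same `(9, 3, 18, 18)` shape: the total image count is the same.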
```diff
@@ -250,3 +257,37 @@ def test_call_pytorch(self):
     @unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
     def test_call_numpy_4_channels(self):
         pass
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched_pas(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+        crop_config = {
+            "do_pan_and_scan": True,
+            "pan_and_scan_max_num_crops": 448,
+            "pan_and_scan_min_crop_size": 32,
+            "pan_and_scan_min_ratio_to_activate": 0.3,
+        }
+        image_processor_dict = self.image_processor_dict
+        image_processor_dict.update(crop_config)
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+        image_processor_slow = self.image_processing_class(**image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
+
+        torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
+        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+        self.assertLessEqual(
+            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+        )
```

> **Review comment** (on the `torch.allclose` line): we can use […]
>
> **Reply:** I'll try to see if I can make it work, but I've had some issues with […]
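The assertions above combine a loose per-element bound (`torch.allclose` with `atol=1e-1`) with a tighter bound on the mean absolute difference (`1e-3`); the thread also discusses `torch.testing.assert_close`, which raises with a diagnostic message instead of returning a boolean. A small sketch of how the two APIs differ in behavior (illustrative data, not the processor outputs):

```python
import torch

# Two tensors differing by small uniform noise, standing in for the
# slow and fast processor outputs.
slow = torch.zeros(4, 3)
fast = slow + 0.05  # every element differs by 0.05

# Boolean check, as used in the test: 0.05 <= 1e-1, so this passes.
assert torch.allclose(slow, fast, atol=1e-1)

# torch.testing.assert_close uses much tighter default tolerances for
# float32, so it raises AssertionError here...
try:
    torch.testing.assert_close(slow, fast)
    raise RuntimeError("expected an AssertionError")
except AssertionError:
    pass

# ...but passes once the same explicit tolerance is supplied.
torch.testing.assert_close(slow, fast, atol=1e-1, rtol=0)
```

`assert_close` has the advantage of reporting the greatest mismatch and its location on failure, which is why reviewers often prefer it over `self.assertTrue(torch.allclose(...))`.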
**Review comment:** Not related to this PR in particular, but seeing a second grouping by size while we are already working on a grouped image batch leads me to believe the batching logic in the fast processors is over-complicated. It would be nice if we could simplify things, especially for community contributors adding a new model with new processing like PaS.
**Reply:** Yes, agreed, this doesn't look great. This is quite a specific case: we start with images of potentially different sizes, then split them into patches and concatenate the patches with the original image (which is bigger than the patches), before resizing everything to the same size 😅.
This new code is a bit overkill in that it groups the patches and images by size at every step, but that isn't really necessary in a first implementation by external contributors, so hopefully they won't ever have to do this to get a working fast image processor.
Not sure if there's a simpler way to make full use of batch processing; I guess it's case by case.
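The "group by size" pattern under discussion can be sketched roughly as follows: bucket tensors by shape so each bucket can be stacked and processed in one batched call, then scatter the results back into the original order. This is a hypothetical illustration of the pattern, not the actual transformers helper:

```python
# Hypothetical sketch of the "group by shape -> batched op -> restore order"
# pattern discussed above; not the actual transformers implementation.
from collections import defaultdict

import torch

def process_grouped_by_shape(images, fn):
    """Run `fn` on shape-homogeneous stacks, then restore the original order."""
    groups = defaultdict(list)  # shape -> [(original_index, image), ...]
    for idx, img in enumerate(images):
        groups[tuple(img.shape)].append((idx, img))
    out = [None] * len(images)
    for items in groups.values():
        stacked = torch.stack([img for _, img in items])  # one batched call per shape
        for (idx, _), result in zip(items, fn(stacked)):
            out[idx] = result
    return out

# Mixed-size "images": base images and smaller crops, as in pan and scan.
imgs = [torch.ones(3, 8, 8), torch.ones(3, 4, 4), torch.ones(3, 8, 8)]
doubled = process_grouped_by_shape(imgs, lambda batch: batch * 2)
assert [tuple(t.shape) for t in doubled] == [(3, 8, 8), (3, 4, 4), (3, 8, 8)]
```

The complaint in the thread is that pan and scan forces this regrouping at every step (after cropping, after concatenating crops with the base image), which is what makes the fast-processor code hard to follow.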
**Review comment:** Yeah, no rush to work on that. I don't think we have been forcing users to add only fast processors for now. We can come back to this question later; maybe we'll find a better way to batch, or we can add some guides on how to add special processing in fast image processors.
**Review comment:** ➕ on spending time to write simpler code! We can find some magic here and there.
In this specific case, I don't necessarily know, but sometimes padding and then unpadding can lead to better performance.