75 changes: 35 additions & 40 deletions src/transformers/models/gemma3/image_processing_gemma3.py
@@ -35,7 +35,7 @@
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_nested_list_of_images,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
@@ -334,9 +334,9 @@ def preprocess(
else self.pan_and_scan_min_ratio_to_activate
)

images_list = make_nested_list_of_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images_list[0]):
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
@@ -353,58 +353,53 @@
resample=resample,
)
if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
images = [to_numpy_array(image) for image in images]

if do_rescale and is_scaled_image(images_list[0][0]):
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
input_data_format = infer_channel_dimension_format(images[0])

if do_pan_and_scan:
images_list_and_num_crops = [
self._process_images_for_pan_and_scan(
images=images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
for images in images_list
]
images_list = [images for images, _ in images_list_and_num_crops]
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
images, num_crops = self._process_images_for_pan_and_scan(
images=images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)

else:
num_crops = [[0] for _ in images_list]
num_crops = [0 for _ in images]

processed_images = []
for images in images_list:
for image in images:
if do_resize:
height, width = size["height"], size["width"]
image = resize(
image=image, size=(height, width), resample=resample, input_data_format=input_data_format
)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)

image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
processed_images.append(image)
for image in images:
if do_resize:
height, width = size["height"], size["width"]
image = resize(
image=image, size=(height, width), resample=resample, input_data_format=input_data_format
)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)

image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
processed_images.append(image)

data = {"pixel_values": processed_images, "num_crops": num_crops}
return BatchFeature(data=data, tensor_type=return_tensors)
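
Note on the structural change above: `make_flat_list_of_images` accepts both a flat list and a nested per-sample list of images and flattens it, so the slow processor now returns a flat `num_crops` list with one entry per image. A minimal usage sketch follows; the checkpoint name is illustrative and assumes a Gemma3 checkpoint is reachable.

```python
# Minimal usage sketch (not part of this diff); the checkpoint name is
# illustrative and assumes a Gemma3 checkpoint is available.
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("google/gemma-3-4b-it", use_fast=False)
image = Image.fromarray(np.random.randint(0, 255, (448, 896, 3), dtype=np.uint8))

# A flat list and a nested per-sample list are flattened to the same structure
# by make_flat_list_of_images, so both calls produce identically shaped outputs.
flat = processor([image, image], return_tensors="pt")
nested = processor([[image], [image]], return_tensors="pt")
assert flat.pixel_values.shape == nested.pixel_values.shape
```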
143 changes: 56 additions & 87 deletions src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -16,7 +16,6 @@

import itertools
import math
from functools import partial
from typing import List, Optional, Union

from ...image_processing_utils_fast import (
@@ -31,11 +30,8 @@
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
SizeDict,
get_image_size,
make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import (
@@ -103,52 +99,9 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorKwargs]):
super().__init__(**kwargs)

def _prepare_images_structure(
def pan_and_scan_batched(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare the images structure for processing.

Args:
images (`ImageInput`):
The input images to process.

Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_nested_list_of_images(images)

def _prepare_input_images(
self,
images: ImageInput,
do_convert_rgb: bool = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> List["torch.Tensor"]:
"""
Prepare the input images for processing.
"""
batch_images = self._prepare_images_structure(images)
process_image_fn = partial(
self._process_image,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# todo: yoni - check if we can parallelize this efficiently
batch_processed_images = []
for image_list in batch_images:
processed_images = []
for image in image_list:
processed_images.append(process_image_fn(image))
batch_processed_images.append(processed_images)

return batch_processed_images

def pan_and_scan(
self,
image: "torch.Tensor",
images: "torch.Tensor",
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
@@ -167,7 +120,7 @@ def pan_and_scan(
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
"""
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
height, width = images.shape[-2:]

# Square or landscape image.
if width >= height:
@@ -210,7 +163,7 @@ def pan_and_scan(
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]

return [
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
images[..., pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
]
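
A self-contained sketch of the batched crop indexing returned above: slicing with `...` keeps the leading batch and channel dimensions, so each crop still covers the whole (same-shape) batch. The crop sizes and positions below are illustrative rather than values computed by the elided logic.

```python
# Illustrative batched cropping; sizes and positions are made up, the point is
# that ellipsis indexing keeps the (batch, channels) dims of every crop.
import itertools
import torch

images = torch.randn(2, 3, 64, 96)        # (batch, channels, height, width)
crop_size_h, crop_size_w = 64, 48
crop_positions_h, crop_positions_w = [0], [0, 48]

crops = [
    images[..., pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
    for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
]
assert all(crop.shape == (2, 3, 64, 48) for crop in crops)
```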

@@ -222,18 +175,14 @@ def _process_images_for_pan_and_scan(
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
):
pas_images_list = []
num_crops = []
for image in images:
pas_images = self.pan_and_scan(
image=image,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
)
pas_images_list.extend([image] + pas_images)
num_crops.append(len(pas_images))
return pas_images_list, num_crops
pas_images = self.pan_and_scan_batched(
images=images,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
)
num_crops = [len(pas_images) for _ in images]
return pas_images, num_crops

@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
@@ -274,46 +223,66 @@ def _preprocess(
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
batch_num_crops = []

for images_list in images:
# Group images by size for batched processing
processed_images_grouped = {}
num_crops_grouped = {}
grouped_images, grouped_images_index = group_images_by_shape(images)
for shape_images, stacked_images in grouped_images.items():
if do_pan_and_scan:
images_list, num_crops = self._process_images_for_pan_and_scan(
images=images_list,
pas_images, num_crops = self._process_images_for_pan_and_scan(
images=stacked_images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
)
# Add the thumbnails to the image patches
stacked_images = [stacked_images] + pas_images
# Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(stacked_images)
for shape, stacked_image_patches in grouped_image_patches.items():
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
# Transpose to have the thumbnails with their corresponding patches
stacked_images = torch.stack(processed_image_patches, dim=0).transpose(0, 1).contiguous()
Comment on lines +239 to +253

Member:
Not related to this PR particularly: seeing a second grouping by size while we are already working on a grouped image batch leads me to believe the batching logic in fast processors is over-complicated. It would be nice if we could simplify things, especially for community contributors when they add a new model with new processing like PaS.

Member Author:
Yes, this doesn't look great, agreed. This is quite a specific case where we start with images of potentially different sizes, then split them into patches and concatenate the patches with the original image (which is bigger than the patches), before resizing everything to the same size 😅.
This new code is a bit overkill in that we group the patches and images by size at every step, but that isn't really necessary for a first implementation by external contributors, so hopefully they won't ever have to do this to get a working fast image processor.
I'm not sure there is a simpler way to fully use batch processing; I guess it's case by case.

Member:
Yeah, no rush to work on that. I don't think we have been forcing users to add only fast processors so far. We can come back to this question later; maybe we'll find a better way to batch, or we can add some guides on how to add special processing in fast image processors.

Collaborator:
➕ on spending time to write simpler code! We can find some magic here and there.
In this specific case, I don't necessarily know, but sometimes padding then unpadding can lead to better performance (see the sketch after this list):

  • Ordering can be costly
  • Depending on the size of the padding, we are not really adding too much compute
  • The distribution of image sizes is important to take into account!
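
A minimal sketch of the padding-then-unpadding idea mentioned above (an illustration of the idea, not code from this PR): pad every image to the largest height and width so a single batched op can run, then crop back to the original sizes. The doubling transform is a stand-in for a real batched op.

```python
# Sketch of the pad-then-unpad alternative: pad to the max size, run one
# batched transform, then crop back to each original size.
import torch
import torch.nn.functional as F

images = [torch.randn(3, 50, 60), torch.randn(3, 64, 40)]
max_h = max(img.shape[-2] for img in images)
max_w = max(img.shape[-1] for img in images)

padded = torch.stack(
    [F.pad(img, (0, max_w - img.shape[-1], 0, max_h - img.shape[-2])) for img in images]
)
processed = padded * 2.0  # stand-in for a real batched transform
unpadded = [out[..., : img.shape[-2], : img.shape[-1]] for out, img in zip(processed, images)]
assert [tuple(t.shape) for t in unpadded] == [(3, 50, 60), (3, 64, 40)]
```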

else:
num_crops = [[0] for _ in images_list]
num_crops = [0 for _ in stacked_images]

# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
stacked_images = self.resize(
image=stacked_images,
size=size,
interpolation=interpolation,
)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
num_crops_grouped[shape_images] = num_crops
processed_images_grouped[shape_images] = stacked_images
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
# If pan and scan is enabled, we need to flatten the list of images
if do_pan_and_scan:
resized_images = [image for images_list in resized_images for image in images_list]
num_crops = reorder_images(num_crops_grouped, grouped_images_index)

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images.extend(processed_image_patches)
batch_num_crops.extend(num_crops)
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors
data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
)
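
For readers unfamiliar with the grouping helpers used throughout `_preprocess`, here is a self-contained re-implementation of the group-by-shape / batch-process / reorder pattern; it mimics what `group_images_by_shape` and `reorder_images` do, but it is not the actual transformers helpers.

```python
# Self-contained sketch of the group-by-shape -> batched op -> reorder pattern;
# a re-implementation of the idea, not the real transformers helpers.
from collections import defaultdict
import torch

def group_by_shape(images):
    grouped, index = defaultdict(list), []
    for img in images:
        shape = tuple(img.shape)
        index.append((shape, len(grouped[shape])))   # remember original position
        grouped[shape].append(img)
    return {shape: torch.stack(imgs) for shape, imgs in grouped.items()}, index

def reorder(processed_grouped, index):
    return [processed_grouped[shape][pos] for shape, pos in index]

images = [torch.randn(3, 32, 32), torch.randn(3, 64, 64), torch.randn(3, 32, 32)]
grouped, index = group_by_shape(images)
processed = {shape: batch * 2.0 for shape, batch in grouped.items()}  # one batched op per group
restored = reorder(processed, index)
assert [tuple(t.shape) for t in restored] == [(3, 32, 32), (3, 64, 64), (3, 32, 32)]
```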


5 changes: 3 additions & 2 deletions src/transformers/models/gemma3/processing_gemma3.py
@@ -113,7 +113,8 @@ def __call__(
)

# Replace image tokens by the full expanded sequence
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
num_crops = to_py_obj(image_inputs.pop("num_crops"))
batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
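
The list comprehension above re-nests the flat per-image `num_crops` returned by the image processor to match the per-sample image structure. A small worked example with illustrative values:

```python
# Worked example of the re-nesting above; the values are illustrative.
batched_images = [["img0", "img1"], ["img2"]]   # 2 samples, 3 images in total
num_crops = [2, 0, 4]                           # one entry per flattened image

batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
assert batch_num_crops == [[2, 0], [4]]
```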

@@ -139,7 +140,7 @@
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")

# Add token type ids manually, as tokenizer can't do arbitrary position token types
array_ids = np.array(text_inputs["input_ids"])
array_ids = text_inputs["input_ids"]
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
41 changes: 41 additions & 0 deletions tests/models/gemma3/test_image_processing_gemma3.py
@@ -189,6 +189,13 @@ def test_pan_and_scan(self):
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

# Test unbalanced batched input: 9 images because we have the base image + 2 crops for each item
encoded_images = image_processing(
[[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
).pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
@@ -250,3 +257,37 @@ def test_call_pytorch(self):
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
def test_call_numpy_4_channels(self):
pass

@require_vision
@require_torch
def test_slow_fast_equivalence_batched_pas(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
crop_config = {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
image_processor_dict = self.image_processor_dict
image_processor_dict.update(crop_config)
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**image_processor_dict)

encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))

Member:
We can use torch.testing.assert_close here also; AFAIR it works better for tensor match tests.

Member Author:
I'll try to see if I can make it work, but I've had some issues with torch.testing.assert_close when comparing pixel values: the relative error on some pixels can be very high, and it's difficult to choose an rtol value that will work for all processors (the base tests comparing slow and fast also don't use torch.testing.assert_close to compare pixel values).

self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
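
To make the tolerance trade-off discussed above concrete, here is a small sketch contrasting an elementwise `torch.testing.assert_close` bound with the aggregate mean-absolute-error bound used in the test; the tensors and tolerances are illustrative.

```python
# Illustration of the two tolerance styles discussed above; values are made up.
import torch

slow = torch.rand(9, 3, 18, 18)
fast = slow + 1e-4 * torch.randn_like(slow)

# Elementwise bound: strict per-pixel check (rtol alone is hard to tune because
# pixel values close to zero blow up the relative error).
torch.testing.assert_close(fast, slow, rtol=0, atol=1e-3)

# Aggregate bound: looser, but uniform across processors.
assert torch.mean(torch.abs(fast - slow)).item() <= 1e-3
```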