From 23ec9747a0b4238e1364fb3dccd0021de1b65f38 Mon Sep 17 00:00:00 2001
From: Oussama Elachqar
Date: Mon, 2 Jun 2025 14:48:23 -0700
Subject: [PATCH 1/2] update

---
 .../collators/vision_language_sft_collator.py | 214 ++++++++++++++++--
 1 file changed, 191 insertions(+), 23 deletions(-)

diff --git a/src/oumi/core/collators/vision_language_sft_collator.py b/src/oumi/core/collators/vision_language_sft_collator.py
index 91e5d02a86..617b27b1e4 100644
--- a/src/oumi/core/collators/vision_language_sft_collator.py
+++ b/src/oumi/core/collators/vision_language_sft_collator.py
@@ -12,6 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Vision-Language SFT collator for conversation-based multimodal training.
+
+This module provides a collator specifically designed for supervised fine-tuning (SFT)
+of vision-language models using conversation data.
+
+Unlike VisionLanguageCollatorWithPadding which expects pre-processed features,
+this collator works with raw conversation objects and handles the complete feature
+generation pipeline.
+
+Example:
+    >>> from oumi.builders import build_tokenizer
+    >>> from oumi.core.configs import ModelParams
+    >>> tokenizer = build_tokenizer(ModelParams(model_name="llava-hf/llava-1.5-7b-hf"))
+    >>> collator = VisionLanguageSftCollator(
+    ...     tokenizer=tokenizer,
+    ...     processor_name="llava-hf/llava-1.5-7b-hf",
+    ...     max_length=512,
+    ...     truncation=True
+    ... )
+    >>> # Expects batch items with conversation_json field
+    >>> batch = collator([{"conversation_json": conversation1.to_json()}, ...])
+"""
+
 from typing import Any, Optional
 
 from oumi.core.feature_generators import (
@@ -20,9 +43,31 @@
 )
 from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
 from oumi.core.types import Conversation
+from oumi.utils.torch_utils import pad_to_max_dim_and_stack
 
 
 class VisionLanguageSftCollator:
+    """Collator for vision-language SFT that processes conversation data.
+
+    This collator is designed for supervised fine-tuning of vision-language models
+    where training data comes in the form of conversations containing both text and
+    images. It handles the complete pipeline from raw conversations to model-ready
+    tensor batches.
+
+    Key Features:
+    - Processes Conversation objects containing text and image data
+    - Uses model-specific processors to extract image features
+    - Handles tokenization and feature generation in one step
+    - Supports various vision-language architectures
+    - Manages padding, truncation, and label masking
+
+    The collator expects batch items with a "conversation_json" field containing
+    serialized Conversation objects. These conversations can include:
+    - Multiple turns of dialogue
+    - Image references (paths, URLs, or base64 data)
+    - System prompts and user/assistant messages
+    """
+
     def __init__(
         self,
         tokenizer: BaseTokenizer,
@@ -35,27 +80,55 @@ def __init__(
         label_ignore_index: Optional[int] = None,
         allow_multi_image_inputs: bool = True,
         trust_remote_code: bool = False,
+        process_individually: bool = False,
     ):
-        """Custom collator for multi-modal vision-language training.
+        """Initializes the vision-language SFT collator.
 
         Args:
-            tokenizer: The tokenizer used for encoding the data.
-            processor_name: The name of the processor to use for feature generation.
-            processor_kwargs: A dictionary of processor-specific parameters.
-                These parameters are passed to the processor constructor.
-                They can override model-specific parameters.
-            max_length: Padding length.
-            truncation: Whether to truncate long inputs to `max_length`.
-                If False, the long inputs are preserved as is even if they exceed
-                `max_length`. Only has effect if `max_length` is specified.
-            truncation_side: The side to truncate the tokens ("right" or "left").
-            label_ignore_index: If set, then label values of tokens that shouldn't
-                contribute to the loss computation will be replaced by
-                this special value.
-            allow_multi_image_inputs: Whether to allow multi-image inputs.
-            trust_remote_code: Whether to trust remote code execution for the processor.
+            tokenizer: The tokenizer for encoding text. Should match the model's
+                tokenizer for proper token alignment.
+
+            processor_name: Name or path of the processor to use for feature extraction.
+                This should typically match the model name.
+                The processor handles image preprocessing and feature extraction.
+
+            processor_kwargs: Optional parameters to pass to the processor constructor.
+                These can override default settings or model-specific parameters.
+
+            max_length: Maximum sequence length for padding/truncation. If None,
+                sequences are padded to the batch maximum. If specified, sequences
+                are padded to this length and may be truncated.
+
+            truncation: Whether to truncate sequences exceeding max_length.
+                If False, long sequences are kept intact. Only applies when
+                max_length is specified.
+
+            truncation_side: Which side to truncate from ("right" or "left").
+                Most models use "right" truncation, but some may require "left"
+                for specific architectures or tasks.
+
+            label_ignore_index: Value to use for masking labels in loss computation.
+
+            allow_multi_image_inputs: Whether to support multiple images per
+                conversation.
+                Set to True for models like MLLaMA that handle multiple images.
+                Set to False for models that only support single images per example.
+
+            trust_remote_code: Whether to trust and execute remote code when loading
+                the processor. Required for some models (e.g., Qwen2-VL) that use
+                custom processing code.
+
+            process_individually: Whether to process each conversation individually
+                and then collate features by padding to max dimensions. When True:
+                - Each conversation is processed separately through the feature
+                  generator
+                - Features are padded to the maximum size in the batch
+                - Useful for models with variable-sized outputs or heterogeneous data
+                - May be less efficient but more flexible than batch processing
+                When False (default), conversations are processed as a batch.
         """
         self._allow_multi_image_inputs = allow_multi_image_inputs
+        self._process_individually = process_individually
 
         if not processor_name:
             raise ValueError("processor_name is required for VisionLanguageSftCollator")
@@ -75,13 +148,51 @@ def __init__(
         )
 
     def __call__(self, batch) -> dict[str, Any]:
-        """Custom collator for multi-modal vision-language training.
+        """Process a batch of conversation data into model-ready features.
+
+        This method converts serialized conversations into the tensor format expected
+        by vision-language models. It handles the complete pipeline:
+        1. Deserializes conversation JSON strings
+        2. Passes conversations to the feature generator
+        3. Returns batched tensors ready for training
 
         Args:
-            batch: List of batch items.
+            batch: List of dictionaries, where each dictionary must contain a
+                "conversation_json" field with a serialized Conversation object.
+
+                Expected format:
+                [
+                    {"conversation_json": '{"messages": [...], "images": [...]}'},
+                    {"conversation_json": '{"messages": [...], "images": [...]}'},
+                    ...
+                ]
+
+                The conversation JSON should include:
+                - messages: List of message dictionaries with role and content
+                - images: Optional list of image data (paths, URLs, or base64)
 
         Returns:
-            Dict[str, torch.Tensor]: Processed batch.
+            Dictionary containing all features needed for model training:
+            - "input_ids": Token IDs including image placeholders
+            - "attention_mask": Attention masks for the input
+            - "labels": Target labels with appropriate masking
+            - "pixel_values" or model-specific image features
+            - Additional model-specific features (cross_attention_mask, etc.)
+
+            The exact keys depend on the model architecture and processor used.
+
+        Raises:
+            ValueError: If batch is empty or any item lacks "conversation_json" field.
+
+        Example:
+            >>> conversation = Conversation(messages=[
+            ...     {"role": "user", "content": "What's in this image?"},
+            ...     {"role": "assistant", "content": "I see a cat."}
+            ... ], images=["path/to/image.jpg"])
+            >>> batch_item = {"conversation_json": conversation.to_json()}
+            >>> features = collator([batch_item])
+            >>> print(features.keys())
+            dict_keys(['input_ids', 'attention_mask', 'labels', 'pixel_values'])
         """
         batch_size = len(batch)
         if batch_size <= 0:
@@ -101,9 +212,66 @@ def __call__(self, batch) -> dict[str, Any]:
             conversations.append(Conversation.from_json(conversation_json))
         assert len(conversations) == batch_size
 
-        result = self._conversation_feature_generator.transform_conversations(
-            conversations,
-            FeatureGeneratorOptions(allow_feature_reshape=False),
-        )
+        if self._process_individually:
+            individual_results = []
+            for conversation in conversations:
+                single_result = (
+                    self._conversation_feature_generator.transform_conversations(
+                        [conversation],
+                        FeatureGeneratorOptions(allow_feature_reshape=False),
+                    )
+                )
+                individual_results.append(single_result)
+
+            # Collate features by padding to max dimensions
+            result = self._collate_individual_results(individual_results)
+        else:
+            result = self._conversation_feature_generator.transform_conversations(
+                conversations,
+                FeatureGeneratorOptions(allow_feature_reshape=False),
+            )
 
         return result
+
+    def _collate_individual_results(
+        self, results: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        """Collate individually processed results by padding to max dimensions.
+
+        Args:
+            results: List of feature dictionaries from individual conversation
+                processing
+
+        Returns:
+            Collated dictionary with padded tensors
+
+        Raises:
+            ValueError: If results have inconsistent keys or non-tensor values
+        """
+        if not results:
+            return {}
+
+        # Get keys from first result and verify consistency
+        expected_keys = set(results[0].keys())
+        for i, result in enumerate(results[1:], 1):
+            if set(result.keys()) != expected_keys:
+                raise ValueError(
+                    f"Inconsistent keys in batch. Expected {expected_keys}, "
+                    f"but result {i} has {set(result.keys())}"
+                )
+
+        # Collate each feature
+        collated = {}
+        for key in expected_keys:
+            values = [result[key] for result in results]
+
+            # Determine max variable dimensions based on feature type
+            # For multi-image models, we may need 2 variable dims (num_images, seq_len)
+            max_var_dims = 2 if self._allow_multi_image_inputs else 1
+
+            # Pad and stack tensors
+            collated[key] = pad_to_max_dim_and_stack(
+                values, max_variable_sized_dims=max_var_dims
+            )
+
+        return collated

From 86521746671785596c0e2cbd4057f56383dd2ff1 Mon Sep 17 00:00:00 2001
From: Oussama Elachqar
Date: Mon, 2 Jun 2025 15:33:32 -0700
Subject: [PATCH 2/2] update

---
 src/oumi/core/collators/vision_language_sft_collator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/oumi/core/collators/vision_language_sft_collator.py b/src/oumi/core/collators/vision_language_sft_collator.py
index 617b27b1e4..c32d9c03d7 100644
--- a/src/oumi/core/collators/vision_language_sft_collator.py
+++ b/src/oumi/core/collators/vision_language_sft_collator.py
@@ -248,12 +248,12 @@ def _collate_individual_results(
         Raises:
             ValueError: If results have inconsistent keys or non-tensor values
         """
-        if not results:
+        if not results or len(results) == 0:
             return {}
 
         # Get keys from first result and verify consistency
         expected_keys = set(results[0].keys())
-        for i, result in enumerate(results[1:], 1):
+        for i, result in enumerate(results):
             if set(result.keys()) != expected_keys:
                 raise ValueError(
                     f"Inconsistent keys in batch. Expected {expected_keys}, "
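
Note on usage: the sketch below shows how the new `process_individually` flag could be
wired into the constructor call from the module docstring added in PATCH 1/2. The
checkpoint name is the one used in that docstring, and `conversations` stands in for a
hypothetical list of `Conversation` objects; this is an illustration, not a canonical
configuration for any particular model.

    from oumi.builders import build_tokenizer
    from oumi.core.collators.vision_language_sft_collator import VisionLanguageSftCollator
    from oumi.core.configs import ModelParams

    tokenizer = build_tokenizer(ModelParams(model_name="llava-hf/llava-1.5-7b-hf"))
    collator = VisionLanguageSftCollator(
        tokenizer=tokenizer,
        processor_name="llava-hf/llava-1.5-7b-hf",
        max_length=512,
        truncation=True,
        # New in PATCH 1/2: featurize each conversation on its own, then pad the
        # resulting feature tensors to the batch maximum and stack them.
        process_individually=True,
    )
    # Same calling convention as before: each item carries a serialized Conversation.
    features = collator([{"conversation_json": c.to_json()} for c in conversations])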
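
For reference, `_collate_individual_results` delegates the actual padding to
`pad_to_max_dim_and_stack` from `oumi.utils.torch_utils`. The snippet below is not that
helper; it is a minimal, self-contained illustration in plain PyTorch of the behavior the
docstring describes: zero-pad each per-conversation tensor up to the per-dimension
maximum across the batch, then stack along a new leading batch dimension. (The real
helper additionally limits how many dimensions may vary, via the
`max_variable_sized_dims` argument seen in the patch, which this sketch omits.)

    import torch
    import torch.nn.functional as F

    def pad_and_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
        """Zero-pads each tensor to the elementwise-max shape, then stacks them."""
        ndim = tensors[0].dim()
        target = [max(t.shape[d] for t in tensors) for d in range(ndim)]
        padded = []
        for t in tensors:
            # F.pad expects (left, right) padding pairs starting from the LAST dim.
            pad_spec: list[int] = []
            for d in reversed(range(ndim)):
                pad_spec.extend([0, target[d] - t.shape[d]])
            padded.append(F.pad(t, pad_spec, value=0))
        return torch.stack(padded, dim=0)

    # Two conversations featurized individually yield different sequence lengths.
    input_ids_a = torch.tensor([1, 2, 3, 4, 5])
    input_ids_b = torch.tensor([6, 7])
    print(pad_and_stack([input_ids_a, input_ids_b]).shape)  # torch.Size([2, 5])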