214 changes: 191 additions & 23 deletions src/oumi/core/collators/vision_language_sft_collator.py
@@ -12,6 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-Language SFT collator for conversation-based multimodal training.

This module provides a collator specifically designed for supervised fine-tuning (SFT)
of vision-language models using conversation data.

Unlike VisionLanguageCollatorWithPadding which expects pre-processed features,
this collator works with raw conversation objects and handles the complete feature
generation pipeline.

Example:
>>> from oumi.builders import build_tokenizer
>>> from oumi.core.configs import ModelParams
>>> tokenizer = build_tokenizer(ModelParams(model_name="llava-hf/llava-1.5-7b-hf"))
>>> collator = VisionLanguageSftCollator(
... tokenizer=tokenizer,
... processor_name="llava-hf/llava-1.5-7b-hf",
... max_length=512,
... truncation=True
... )
>>> # Expects batch items with conversation_json field
>>> batch = collator([{"conversation_json": conversation1.to_json()}, ...])
"""

from typing import Any, Optional

from oumi.core.feature_generators import (
@@ -20,9 +43,31 @@
)
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.core.types import Conversation
from oumi.utils.torch_utils import pad_to_max_dim_and_stack


class VisionLanguageSftCollator:
"""Collator for vision-language SFT that processes conversation data.

This collator is designed for supervised fine-tuning of vision-language models
where training data comes in the form of conversations containing both text and
images. It handles the complete pipeline from raw conversations to model-ready
tensor batches.

Key Features:
- Processes Conversation objects containing text and image data
- Uses model-specific processors to extract image features
- Handles tokenization and feature generation in one step
- Supports various vision-language architectures
- Manages padding, truncation, and label masking

The collator expects batch items with a "conversation_json" field containing
serialized Conversation objects. These conversations can include:
- Multiple turns of dialogue
- Image references (paths, URLs, or base64 data)
- System prompts and user/assistant messages
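
Example:
A minimal sketch of the expected input format (assumes a collator built as in
the module-level example and an existing Conversation object named
``conversation``):

>>> item = {"conversation_json": conversation.to_json()}
>>> features = collator([item])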
"""

def __init__(
self,
tokenizer: BaseTokenizer,
@@ -35,27 +80,55 @@ def __init__(
label_ignore_index: Optional[int] = None,
allow_multi_image_inputs: bool = True,
trust_remote_code: bool = False,
process_individually: bool = False,
):
"""Custom collator for multi-modal vision-language training.
"""Initializes the vision-language SFT collator.

Args:
tokenizer: The tokenizer used for encoding the data.
processor_name: The name of the processor to use for feature generation.
processor_kwargs: A dictionary of processor-specific parameters.
These parameters are passed to the processor constructor.
They can override model-specific parameters.
max_length: Padding length.
truncation: Whether to truncate long inputs to `max_length`.
If False, the long inputs are preserved as is even if they exceed
`max_length`. Only has effect if `max_length` is specified.
truncation_side: The side to truncate the tokens ("right" or "left").
label_ignore_index: If set, then label values of tokens that shouldn't
contribute to the loss computation will be replaced by
this special value.
allow_multi_image_inputs: Whether to allow multi-image inputs.
trust_remote_code: Whether to trust remote code execution for the processor.
tokenizer: The tokenizer for encoding text. Should match the model's
tokenizer for proper token alignment.

processor_name: Name or path of the processor to use for feature extraction.
This should typically match the model name.
The processor handles image preprocessing and feature extraction.

processor_kwargs: Optional parameters to pass to the processor constructor.
These can override default settings or model-specific parameters.

max_length: Maximum sequence length for padding/truncation. If None,
sequences are padded to the batch maximum. If specified, sequences
are padded to this length and may be truncated.

truncation: Whether to truncate sequences exceeding max_length.
If False, long sequences are kept intact. Only applies when
max_length is specified.

truncation_side: Which side to truncate from ("right" or "left").
Most models use "right" truncation, but some may require "left"
for specific architectures or tasks.

label_ignore_index: Value to use for masking labels in loss computation.

allow_multi_image_inputs: Whether to support multiple images per
conversation.
Set to True for models like MLLaMA that handle multiple images.
Set to False for models that only support single images per example.

trust_remote_code: Whether to trust and execute remote code when loading
the processor. Required for some models (e.g., Qwen2-VL) that use
custom processing code.

process_individually: Whether to process each conversation individually
and then collate features by padding to max dimensions. When True:
- Each conversation is processed separately through the feature
generator
- Features are padded to the maximum size in the batch
- Useful for models with variable-sized outputs or heterogeneous data
- May be less efficient but more flexible than batch processing
When False (default), conversations are processed as a batch.
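
Example:
A minimal construction sketch for per-conversation processing (the model and
processor names are illustrative, and the tokenizer is assumed to be built as
in the module-level example):

>>> collator = VisionLanguageSftCollator(
... tokenizer=tokenizer,
... processor_name="llava-hf/llava-1.5-7b-hf",
... process_individually=True,
... )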
"""
self._allow_multi_image_inputs = allow_multi_image_inputs
self._process_individually = process_individually

if not processor_name:
raise ValueError("processor_name is required for VisionLanguageSftCollator")
@@ -75,13 +148,51 @@ def __init__(
)

def __call__(self, batch) -> dict[str, Any]:
"""Custom collator for multi-modal vision-language training.
"""Process a batch of conversation data into model-ready features.

This method converts serialized conversations into the tensor format expected
by vision-language models. It handles the complete pipeline:
1. Deserializes conversation JSON strings
2. Passes conversations to the feature generator
3. Returns batched tensors ready for training

Args:
batch: List of batch items.
batch: List of dictionaries, where each dictionary must contain a
"conversation_json" field with a serialized Conversation object.

Expected format:
[
{"conversation_json": '{"messages": [...], "images": [...]}'},
{"conversation_json": '{"messages": [...], "images": [...]}'},
...
]

The conversation JSON should include:
- messages: List of message dictionaries with role and content
- images: Optional list of image data (paths, URLs, or base64)

Returns:
Dict[str, torch.Tensor]: Processed batch.
Dictionary containing all features needed for model training:
- "input_ids": Token IDs including image placeholders
- "attention_mask": Attention masks for the input
- "labels": Target labels with appropriate masking
- "pixel_values" or model-specific image features
- Additional model-specific features (cross_attention_mask, etc.)

The exact keys depend on the model architecture and processor used.

Raises:
ValueError: If the batch is empty or any item lacks a "conversation_json" field.

Example:
>>> conversation = Conversation(messages=[
... {"role": "user", "content": "What's in this image?"},
... {"role": "assistant", "content": "I see a cat."}
... ], images=["path/to/image.jpg"])
>>> batch_item = {"conversation_json": conversation.to_json()}
>>> features = collator([batch_item])
>>> print(features.keys())
dict_keys(['input_ids', 'attention_mask', 'labels', 'pixel_values'])
"""
batch_size = len(batch)
if batch_size <= 0:
@@ -101,9 +212,66 @@ def __call__(self, batch) -> dict[str, Any]:
conversations.append(Conversation.from_json(conversation_json))
assert len(conversations) == batch_size

result = self._conversation_feature_generator.transform_conversations(
conversations,
FeatureGeneratorOptions(allow_feature_reshape=False),
)
if self._process_individually:
individual_results = []
for conversation in conversations:
single_result = (
self._conversation_feature_generator.transform_conversations(
[conversation],
FeatureGeneratorOptions(allow_feature_reshape=False),
)
)
individual_results.append(single_result)

# Collate features by padding to max dimensions
result = self._collate_individual_results(individual_results)
else:
result = self._conversation_feature_generator.transform_conversations(
conversations,
FeatureGeneratorOptions(allow_feature_reshape=False),
)

return result

def _collate_individual_results(
self, results: list[dict[str, Any]]
) -> dict[str, Any]:
"""Collate individually processed results by padding to max dimensions.

Args:
results: List of feature dictionaries from individual conversation
processing

Returns:
Collated dictionary with padded tensors

Raises:
ValueError: If results have inconsistent keys or non-tensor values
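
Example:
A hedged sketch of the collation step (output shapes depend on the semantics
of pad_to_max_dim_and_stack, which this method delegates to; ``collator`` is
assumed to be constructed as in the class-level examples):

>>> import torch
>>> feats_a = {"input_ids": torch.tensor([[1, 2, 3]])}
>>> feats_b = {"input_ids": torch.tensor([[4, 5, 6, 7, 8]])}
>>> batch = collator._collate_individual_results([feats_a, feats_b])
>>> # The shorter input_ids tensor is padded so both can be stacked into
>>> # a single batched tensor.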
"""
if not results or len(results) == 0:
return {}

# Get keys from first result and verify consistency
expected_keys = set(results[0].keys())
for i, result in enumerate(results):
if set(result.keys()) != expected_keys:
raise ValueError(
f"Inconsistent keys in batch. Expected {expected_keys}, "
f"but result {i} has {set(result.keys())}"
)

# Collate each feature
collated = {}
for key in expected_keys:
values = [result[key] for result in results]

# Determine max variable dimensions based on feature type
# For multi-image models, we may need 2 variable dims (num_images, seq_len)
max_var_dims = 2 if self._allow_multi_image_inputs else 1

# Pad and stack tensors
collated[key] = pad_to_max_dim_and_stack(
values, max_variable_sized_dims=max_var_dims
)

return collated

Review comment by @nikg4 (Contributor), Jun 2, 2025, on the pad_to_max_dim_and_stack call above:

note that this simple padding logic isn't compatible with all VLM models. Consider adding a note about it

(delegating everything to processor was the original motivation for adding this collator)