Issues on GRPO with VLM and vLLM (cont.) #4488

@Fhrozen


Reproduction

@qgallouedec Thank you for your hard work on #4113

New issues:

Test Code

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("my_custom_dataset", split="train")  # placeholder: a custom dataset with multiple images per prompt

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Llava7b-GRPO")
trainer = GRPOTrainer(
    model="llava-hf/llava-1.5-7b-hf",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Issue 1 - Cannot train GRPO with multiple images and LLaVA-1.5-7B:

  • Output: ValueError: Image features and image tokens do not match: tokens: 3456, features 2359296
  • Reason: The image_grid_thw key is missing from the batch input, which skips the split processing in split_pixel_values_by_grid and makes the image size collapse to 1 during _get_per_token_logps_and_entropies.
  • My implementation:

In trl.trainer.utils:

from itertools import accumulate
from typing import Union

import torch


def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """
    Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
    `batch["image_grid_thw"]` and `batch["num_images"]`, while keeping other entries unchanged.
    """
    if "pixel_values" not in batch or "num_images" not in batch:
        return batch  # type: ignore

    if "image_grid_thw" not in batch:
        # No grid info (e.g. LLaVA-style models): each image contributes exactly one row.
        lengths = sum(batch["num_images"]) * [1]
    else:
        lengths = batch["image_grid_thw"].prod(-1).tolist()  # [num_images]
    pixel_values = batch["pixel_values"]  # [total, feature_dim]

    if sum(lengths) != pixel_values.size(0):
        raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}")

    boundaries = [0, *accumulate(batch["num_images"])]  # [3, 4, 5] -> [0, 3, 7, 12]
    sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
    split_values = list(torch.split(batch["pixel_values"], sections, dim=0))  # type: ignore

    if "image_grid_thw" in batch:
        image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))  # type: ignore
        return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}

    return {**batch, "pixel_values": split_values}
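
For reference, a minimal sketch of how the patched function behaves on a LLaVA-style batch without image_grid_thw (the shapes and values here are illustrative assumptions, not taken from the actual trainer):

import torch

# Hypothetical batch: 3 images total across 2 prompts, pixel_values stacked
# along dim 0, and no image_grid_thw entry (LLaVA-style processor output).
batch = {
    "pixel_values": torch.randn(3, 3, 336, 336),  # [total_images, C, H, W]
    "num_images": [2, 1],                         # images per prompt
}

out = split_pixel_values_by_grid(batch)
# With the patch, lengths = [1, 1, 1], so the split follows num_images:
# out["pixel_values"][0].shape == (2, 3, 336, 336)  (first prompt, 2 images)
# out["pixel_values"][1].shape == (1, 3, 336, 336)  (second prompt, 1 image)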

In trl.trainer.grpo_trainer, in GRPOTrainer._get_per_token_logps_and_entropies:

elif pixel_values is not None:
    cum_imgs = torch.tensor([0] + num_images).cumsum(0)  # type: ignore
    img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
    model_inputs["pixel_values"] = pixel_values[img_start:img_end]
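
As a quick sanity check of this indexing (the values are hypothetical):

import torch

num_images = [2, 1, 3]  # images per prompt in the full batch
cum_imgs = torch.tensor([0] + num_images).cumsum(0)  # tensor([0, 2, 3, 6])

start, batch_size = 1, 2  # micro-batch covering prompts 1 and 2
img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
# img_start = 2, img_end = 6 -> pixel_values[2:6] covers the 1 + 3 images
# belonging to the two prompts in this micro-batch.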

Issue 2 - Cannot train with vLLM serve

  • Output:

    output = self.vllm_client.generate(
      File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in generate
        images = [pil_to_base64(img) for img in images] if images else None
      File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in <listcomp>
        images = [pil_to_base64(img) for img in images] if images else None
      File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 235, in pil_to_base64
        image.save(buffer, format="PNG")
    AttributeError: 'list' object has no attribute 'save'
  • Reason: pil_to_base64 in trl/extras/vllm_client.py expects a flat list of images, not a list of lists of images as produced when there are multiple images per prompt:

    def pil_to_base64(image):
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_bytes = buffer.getvalue()
        return base64.b64encode(img_bytes).decode("utf-8")

    # Convert PIL images to base64 strings
    images = [pil_to_base64(img) for img in images] if images else None
  • Also, the server-side loop

    for prompt, image in zip(request.prompts, request.images):
        row = {"prompt": prompt}
        if image is not None:
            row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
        prompts.append(row)

    seems to expect multimodal data with a single image per prompt.
  • Solutions: I am working on them; a possible direction is sketched below.
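
As a possible direction (an untested sketch, not the final fix), the client-side encoding could accept both flat lists and per-prompt sub-lists; encode_images is a hypothetical helper, and the server-side loop would need the symmetric change (decoding each sub-list back into a list of PIL images under multi_modal_data["image"]):

import base64
from io import BytesIO

def pil_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def encode_images(images):
    # Handle both a flat list of images and a list of lists
    # (one sub-list per prompt, as produced with multiple images per prompt).
    if images is None:
        return None
    encoded = []
    for item in images:
        if isinstance(item, list):  # multiple images for one prompt
            encoded.append([pil_to_base64(img) for img in item])
        else:                       # single image for one prompt
            encoded.append(pil_to_base64(item))
    return encoded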

Issue 3 - Gradient Checkpointing

  • Cannot train with:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

With use_reentrant: false, training warns that gradients will be None, and with use_reentrant: true it raises a "has been marked as ready twice" error. For the moment I have disabled gradient checkpointing, but I will need it in the future, so I am still looking into it.
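
For completeness, the equivalent Python form of the YAML above (gradient_checkpointing and gradient_checkpointing_kwargs are standard transformers.TrainingArguments fields, which GRPOConfig extends):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Llava7b-GRPO",
    gradient_checkpointing=True,
    # use_reentrant=False warns that gradients will be None;
    # use_reentrant=True raises "has been marked as ready twice".
    gradient_checkpointing_kwargs={"use_reentrant": False},
)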

System Info

  • Platform: Linux-6.14.0-1015-x86_64-with-glibc2.39
  • Python version: 3.10.18
  • TRL version: 0.24.0
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA L40S, NVIDIA L40S, NVIDIA L40S, NVIDIA L40S
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • Datasets version: 4.3.0
  • HF Hub version: 0.36.0
  • bitsandbytes version: 0.47.0
  • DeepSpeed version: 0.17.4
  • Liger-Kernel version: 0.6.1
  • LLM-Blender version: not installed
  • OpenAI version: 1.100.2
  • PEFT version: 0.17.1
  • vLLM version: 0.11.0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
