📽 Multi image support for GRPO/RLOO #4113
Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
)
trainer = GRPOTrainer(
    model=model_id,
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
```
We don't support visual reward models, so it doesn't really make sense to test this case, where the image is dropped and a warning is raised.
trl/trainer/grpo_trainer.py
Outdated
```python
# VLM reward models aren't supported yet, so we drop the image and raise a warning if needed
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
            turn["content"] = " ".join(
                e["text"] for e in turn["content"] if e["type"] == "text"
            )
```
from `[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]` to `[{"role": "user", "content": "What color is the sky?"}]`, plus raise a warning.
```python
# We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
```
Instead of keeping the original prompt, we just drop the image later and raise a warning; see https://github.com/huggingface/trl/pull/4113/files#r2364899902
```diff
 # important because rewards will be normalized per group, and completions are distributed. We will later slice
 # rewards_per_func to extract each process's subset.
-rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
```
```python
if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Convert images to wandb Image objects for proper visualization
        table["images"].append([wandb.Image(image) for image in image_list])
```
```python
boundaries = [0, *accumulate(batch["num_images"])]  # [3, 4, 5] -> [0, 3, 7, 12]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
```
Instead of keeping `image_grid_thw` as is, we need to split it depending on the number of images. It gets concatenated later in `_get_per_token_logps_and_entropies` (see line 807).
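A toy, self-contained illustration of that splitting logic; the shapes, the `num_images` values, and the `lengths` variable are made up for the example:

```python
from itertools import accumulate
import torch

# Flattened batch: 3 images total, example 0 owns 2 of them, example 1 owns 1.
num_images = [2, 1]
image_grid_thw = torch.tensor([[1, 4, 4],   # (t, h, w) per image
                               [1, 2, 2],
                               [1, 2, 4]])
lengths = image_grid_thw.prod(-1).tolist()   # patch rows per image: [16, 4, 8]
pixel_values = torch.randn(sum(lengths), 8)  # toy hidden dim of 8

boundaries = [0, *accumulate(num_images)]    # [2, 1] -> [0, 2, 3]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))]
split_values = list(torch.split(pixel_values, sections, dim=0))      # shapes (20, 8) and (8, 8)
split_grids = list(torch.split(image_grid_thw, num_images, dim=0))   # shapes (2, 3) and (1, 3)
print([v.shape for v in split_values], [g.shape for g in split_grids])
```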
trl/trainer/grpo_trainer.py
Outdated
```python
model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size])
start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
```
See https://github.com/huggingface/trl/pull/4113/files#r2364904060: `image_grid_thw` is not a tensor anymore, but a list of tensors.
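To make the indexing concrete, here is a toy continuation under the same assumption (a list of per-example grid tensors); the numbers follow the example above and are not from the PR:

```python
import torch

# Per-example grids: example 0 has two images, example 1 has one.
image_grid_thw = [torch.tensor([[1, 4, 4], [1, 2, 2]]),
                  torch.tensor([[1, 2, 4]])]
start, batch_size = 1, 1  # micro-batch containing only example 1

grid = torch.cat(image_grid_thw[start : start + batch_size])  # tensor([[1, 2, 4]])
start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()  # 16 + 4 = 20
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()           # 20 + 8 = 28
print(grid, start_pixel_idx, end_pixel_idx)
```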
lewtun left a comment
LGTM with a question about whether raising an error vs a warning is best when images + text are being passed to the reward function
```python
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
```
Does the same comment for GRPO apply here? https://github.com/huggingface/trl/pull/4113/files#diff-96dca172e696190fc3e1469166e88aface95ebae959284c6806f2e25d2217c16R1587
trl/trainer/grpo_trainer.py
Outdated
```python
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
```
Would raising an error be better than a warning? Otherwise I could imagine the warning being missed and the training "failing silently", because the reward is only computed on the text part.
Yes, I see. I wonder if anyone would want to train a VLM with a standard LM reward model (i.e., not a visual reward model). But so far, I've never seen that. We could always support it in the future if there is demand for it. I'll remove this warning; if the user tries it, the rendering of the chat template will fail, which prevents ending up in the silent-failure case you describe.
```python
table["images"] = []
for image_list in self._logs["images"]:
    # Convert images to wandb Image objects for proper visualization
    table["images"].append([wandb.Image(image) for image in image_list])
```
At some point it would be nice to also add the trackio variant for table images
This PR belongs to a sequence of PRs that aims to refactor the generation part of GRPO/RLOO to allow for easier customization and ultimately tool calling.

Previous:

- `image_split_sizes` in favour of `image_grid_thw` #4111

Next:

- `_generate` #4114
- `_generate` in GRPO/RLOO: list of ints instead of tensors #4146
- `_generate` in GRPO/RLOO: Use `prompt_ids` from generation #4152
- `_generate` in GRPO/RLOO: Rely on generator for prompt truncation #4153
- `_generate` in GRPO/RLOO: Move `forward_kwargs` outside generation method #4154
- `_generate` in GRPO/RLOO: Insert images in the prompt #4155

While refactoring, I realized that clean multi-image support helps achieve a cleaner separation between functions.
Try with:
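A minimal sketch of what such a multi-image run could look like; the model checkpoint, toy dataset, and reward function below are illustrative assumptions, not the author's original snippet:

```python
# Illustrative only: model name, dataset, and reward function are assumptions.
from datasets import Dataset
from PIL import Image
from trl import GRPOConfig, GRPOTrainer

# Toy dataset: each prompt references two images; the "images" column is assumed
# to hold a list of PIL images per example.
dataset = Dataset.from_dict(
    {
        "prompt": [
            [{"role": "user",
              "content": [{"type": "image"}, {"type": "image"},
                          {"type": "text", "text": "Which of the two images is brighter?"}]}]
        ] * 8,
        "images": [[Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black")]] * 8,
    }
)

def reward_len(completions, **kwargs):
    # Text-only placeholder reward: prefer shorter answers.
    return [-float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-multi-image",
    per_device_train_batch_size=2,
    num_generations=2,
    max_completion_length=32,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-VL-2B-Instruct",  # any VLM supported by the trainer
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```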