Conversation

@albertvillanova (Member)

Refactor vLLM generation [1/N]: Extract vLLM generation.

This PR introduces a new initialization module for vLLM generation in TRL trainers. The main change is the addition of conditional support for the VLLMGeneration backend, which will only be imported and exposed if the vllm dependency is available.
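
For context, the conditional export follows the usual optional-dependency guard pattern. A minimal sketch (the helper and module names here are illustrative assumptions, not the exact diff):

from trl.import_utils import is_vllm_available  # assumed availability helper

if is_vllm_available():
    from .vllm_generation import VLLMGeneration  # hypothetical module name

    __all__ = ["VLLMGeneration"]
else:
    __all__ = []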

@albertvillanova marked this pull request as ready for review on December 16, 2025 05:39
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 107 to 119
# TODO: improve
# Calculate max_model_len: use vllm_max_model_length if available, otherwise compute from prompt+completion
if hasattr(args, "vllm_max_model_length") and args.vllm_max_model_length is not None:
max_model_len = args.vllm_max_model_length
elif (
hasattr(self.trainer, "max_prompt_length")
and self.trainer.max_prompt_length is not None
and hasattr(self.trainer, "max_completion_length")
and self.trainer.max_completion_length is not None
):
max_model_len = self.trainer.max_prompt_length + self.trainer.max_completion_length
else:
max_model_len = None
albertvillanova (Member Author):

This will not be necessary once we merge:

albertvillanova (Member Author):

I have removed it after the merge.

Comment on lines 147 to 150
# TODO: improve
# Add logprobs_mode only for GRPO (not used in RLOO)
if "grpo" in type(self.trainer).__name__.lower():
llm_kwargs["logprobs_mode"] = "processed_logprobs"
albertvillanova (Member Author):

I am also thinking about how to improve this.

albertvillanova (Member Author):

I think we should pass logprobs_mode in the RLOOTrainer as well, after having read the motivation of the PR that introduced this:

I am reverting this condition.
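
A minimal sketch of the reverted form (assuming the same llm_kwargs dict as in the diff above): the mode is passed unconditionally, so both GRPO and RLOO use processed log probs:

# No trainer-class sniffing: always request processed log probs so that
# temperature scaling / logit tweaks are reflected in the returned values.
llm_kwargs["logprobs_mode"] = "processed_logprobs"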

albertvillanova (Member Author):

Done.

Args:
trainer: Reference to parent trainer for accessing config, model, accelerator, etc.
"""
self.trainer = trainer
Member:

I like the idea of having a class dedicated to generation for vLLM now that the amount of code related to generation is starting to saturate the online trainers.

However, I'm a bit annoyed by this back-reference. It's not an ideal design choice in my opinion. In this state, I even have trouble seeing how to properly test this.

Conceptually, the generator doesn't need a trainer to generate. It only needs the generation parameters, probably the accelerator, and a method to update the weights.

Can we think of an alternative design, something like:

class VLLMGeneration:
    def __init__(self, model_id, accelerator, mode="server", ...):
        ...

    def sync_weights(self, model):
        ...

    def generate(self, prompts, temperature, ...):
        ...
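
For illustration, trainer-side usage of such an interface could look roughly like this (names taken from the sketch above; the argument values are placeholders, not a definitive API):

# Rough usage sketch of the proposed interface; not actual TRL code.
generation = VLLMGeneration(
    model_id=model.name_or_path,
    accelerator=self.accelerator,
    mode="colocate",  # or "server"
)
generation.sync_weights(self.model)  # push updated policy weights to vLLM
completions = generation.generate(prompts, temperature=args.temperature)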

albertvillanova (Member Author):

Thanks for your insightful review! 🤗

Yes, I totally agree with you that the back-reference to the trainer is not ideal.

Indeed, this PR is a preliminary step in the complete refactoring and I was planning to change that in a later PR. This was my original plan:

  1. PR 1 (current): Extract the vLLM logic into a separate class (prove the concept, remove duplication)
  2. PR 2: Refine the interface to use explicit parameters instead of trainer reference
  3. PR 3: Add a proper protocol/interface

However, I could fix this now in this PR if you think it is better.

albertvillanova (Member Author), Dec 17, 2025:

I have seen that the trainer instance is used by the rollout_func. This will need further refactoring to separate (but coordinate) rollout_func and vllm functionalities. 🤔

Member:

> I have seen that the trainer instance is used by the rollout_func. This will need further refactoring to separate (but coordinate) rollout_func and vllm functionalities. 🤔

Yes, indeed. The design of rollout_func is really not ideal. We implemented it that way initially because we had to move quickly for the release of OpenEnv, but in my opinion it's really something we should completely rethink and refactor.

Member:

> Indeed, this PR is a preliminary step in the complete refactoring and I was planning to change that in a later PR. This was my original plan:

I'm fine with the multi-stage plan for this. It will probably be easier to review this way, and easier for you to write.

albertvillanova (Member Author):

Additionally, I see that the trainer instance is also used by the profiling_context function. This will need further refactoring as well.

albertvillanova (Member Author):

I created a dedicated PR to decouple profiling from trainer:

@albertvillanova (Member Author):

As this PR has not been approved yet, I have finally addressed both issues here:

  • trainer reference within rollout_func
  • trainer reference within profiling_context

This PR needs the following PR to be merged first, though:

@albertvillanova (Member Author):

I have integrated PR #4717.
CC: @qgallouedec

def __init__(
self,
model,
args,
qgallouedec (Member), Dec 23, 2025:

What do you think of removing args and instead having

  • enable_sleep_mode
  • gpu_memory_utilization
  • group_port
  • guided_decoding_regex
  • max_model_length
  • mode
  • model_impl
  • server_base_url
  • server_host
  • server_timeout
  • tensor_parallel_size

as the arguments of this function?

You may want to add max_num_seqs as well.
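
For illustration, the explicit signature could look roughly like this (only the parameter names come from the list above; the defaults are illustrative assumptions):

def __init__(
    self,
    model,
    mode="colocate",  # or "server"
    enable_sleep_mode=False,
    gpu_memory_utilization=0.9,
    group_port=None,
    guided_decoding_regex=None,
    max_model_length=None,
    max_num_seqs=None,
    model_impl="auto",
    server_base_url=None,
    server_host="0.0.0.0",
    server_timeout=240.0,
    tensor_parallel_size=1,
):
    ...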

albertvillanova (Member Author):

Thanks, @qgallouedec.

I had something similar in mind in my original plan:

  • PR 2: Refine the interface
    • I had in mind creating something like a VLLMGenerationConfig class with all present (and open to future) vLLM parameters (see the sketch below).

If you prefer, I can do all the refactoring in a single PR?
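
A hypothetical sketch of such a config object (the field names mirror the vLLM-related arguments discussed in this thread; this is not an existing TRL class, and the defaults are illustrative):

from dataclasses import dataclass
from typing import Optional

@dataclass
class VLLMGenerationConfig:
    # Hypothetical container for vLLM generation settings.
    mode: str = "colocate"
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    max_model_length: Optional[int] = None
    max_num_seqs: Optional[int] = None
    enable_sleep_mode: bool = False
    model_impl: str = "auto"
    guided_decoding_regex: Optional[str] = None
    server_base_url: Optional[str] = None
    server_host: str = "0.0.0.0"
    server_timeout: float = 240.0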

albertvillanova (Member Author):

@qgallouedec I think I addressed your suggestion above.

qgallouedec (Member), Jan 19, 2026:

Not necessarily for this PR, but since we add a generation submodule, we could have vllm_client in this submodule as well.

@albertvillanova (Member Author):

Oh, thanks a lot for your detailed review!!! 🤗
I will address all your requests.

Comment on lines 228 to 246
# Build LLM initialization kwargs
llm_kwargs = {
"model": model.name_or_path,
"tensor_parallel_size": self.tensor_parallel_size,
"gpu_memory_utilization": self.gpu_memory_utilization,
"max_num_seqs": self.max_num_seqs,
"max_model_len": self.max_model_length,
"distributed_executor_backend": "external_launcher",
# Feed identical seed for tp groups to ensure sampling results are the same across workers
"seed": accelerator.process_index // self.tensor_parallel_size,
# Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
"max_num_batched_tokens": 4096,
"model_impl": self.model_impl,
"enable_sleep_mode": self.enable_sleep_mode,
# Important so temperature scaling/logit tweaking affects the TIS log probs
"logprobs_mode": "processed_logprobs",
"quantization": quantization,
}
self.llm = LLM(**llm_kwargs)
Member:

nit

instead of

llm_kwargs = {
    "model": model.name_or_path,
    ...
}
self.llm = LLM(**llm_kwargs)

we can also simply do

self.llm = LLM(
    model=model.name_or_path,
    ...,
)

qgallouedec (Member) left a comment:

This looks good to me! This will simplify the trainers a lot.

@albertvillanova merged commit 0eb66d8 into huggingface:main on Jan 27, 2026
10 checks passed