Refactor vLLM generation [1/N]: Extract vLLM generation #4700
Conversation
trl/generation/vllm_generation.py
Outdated
# TODO: improve
# Calculate max_model_len: use vllm_max_model_length if available, otherwise compute from prompt+completion
if hasattr(args, "vllm_max_model_length") and args.vllm_max_model_length is not None:
    max_model_len = args.vllm_max_model_length
elif (
    hasattr(self.trainer, "max_prompt_length")
    and self.trainer.max_prompt_length is not None
    and hasattr(self.trainer, "max_completion_length")
    and self.trainer.max_completion_length is not None
):
    max_model_len = self.trainer.max_prompt_length + self.trainer.max_completion_length
else:
    max_model_len = None
This will not be necessary once we merge:
I have removed it after the merge.
trl/generation/vllm_generation.py
Outdated
# TODO: improve
# Add logprobs_mode only for GRPO (not used in RLOO)
if "grpo" in type(self.trainer).__name__.lower():
    llm_kwargs["logprobs_mode"] = "processed_logprobs"
I am also thinking about how to improve this.
After reading the motivation of the PR that introduced this, I think we should pass logprobs_mode in the RLOOTrainer as well:
I am reverting this condition.
Done.
trl/generation/vllm_generation.py
Outdated
        Args:
            trainer: Reference to parent trainer for accessing config, model, accelerator, etc.
        """
        self.trainer = trainer
I like the idea of having a class dedicated to generation for vLLM now that the amount of code related to generation is starting to saturate the online trainers.
However, I'm a bit annoyed by this back-reference. It's not an ideal design choice in my opinion. In this state, I even have trouble seeing how to properly test this.
Conceptually, the generator doesn't need a trainer to generate. It only needs the generation parameters, probably the accelerator, and a method to update the weights.
Can we think of an alternative design, something like:
class VLLMGeneration:
    def __init__(self, model_id, accelerator, mode="server", ...):
        ...

    def sync_weights(self, model):
        ...

    def generate(self, prompts, temperature, ...):
        ...
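For illustration, an online trainer could then drive generation through an interface like that sketch, roughly as follows (names and values here are assumptions, not the actual TRL API):

# Hypothetical usage of the interface sketched above; `accelerator`, `model`, and
# `prompts` are assumed to come from the trainer's context.
generator = VLLMGeneration(
    model_id="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
    accelerator=accelerator,
    mode="colocate",
)
generator.sync_weights(model)  # push the latest policy weights to vLLM
completions = generator.generate(prompts, temperature=0.7)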
Thanks for your insightful review! 🤗
Yes, I totally agree with you that the back-reference to the trainer is not ideal.
Indeed, this PR is a preliminary step in the complete refactoring and I was planning to change that in a later PR. This was my original plan:
- PR 1 (current): Extract the vLLM logic into a separate class (prove the concept, remove duplication)
- PR 2: Refine the interface to use explicit parameters instead of trainer reference
- PR 3: Add a proper protocol/interface
However, I could fix this now in this PR if you think it is better.
I have seen that the trainer instance is used by the rollout_func. This will need further refactoring to separate (but coordinate) rollout_func and vllm functionalities. 🤔
I have seen that the trainer instance is used by the rollout_func. This will need further refactoring to separate (but coordinate) rollout_func and vllm functionalities. 🤔
Yes, indeed. The design of rollout_func is really not ideal. We implemented it that way initially because we had to move quickly for the release of OpenEnv, but in my opinion it's really something we should completely rethink and refactor.
Indeed, this PR is a preliminary step in the complete refactoring and I was planning to change that in a later PR. This was my original plan:
I'm fine with a multi-stage PR for this. It will probably be easier to review this way, and easier for you to write.
Additionally, I see that the trainer instance is also used by the profiling_context function. This will need further refactoring as well.
I created a dedicated PR to decouple profiling from the trainer:
As that PR was not approved, I have finally addressed both issues here:
This PR needs the following PR to be merged first, though:
I have integrated PR #4717.
trl/generation/vllm_generation.py
Outdated
def __init__(
    self,
    model,
    args,
What do you think of removing args and instead having

- enable_sleep_mode
- gpu_memory_utilization
- group_port
- guided_decoding_regex
- max_model_length
- mode
- model_impl
- server_base_url
- server_host
- server_timeout
- tensor_parallel_size

as arguments of this function?
You may want to add max_num_seqs as well.
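For illustration, the explicit-parameter constructor being suggested could look roughly like this (parameter names follow the list above; the defaults are illustrative assumptions, not TRL's actual values):

class VLLMGeneration:
    # Sketch of the suggested signature; defaults are illustrative assumptions.
    def __init__(
        self,
        model,
        mode="server",
        enable_sleep_mode=False,
        gpu_memory_utilization=0.9,
        group_port=51216,
        guided_decoding_regex=None,
        max_model_length=None,
        max_num_seqs=None,
        model_impl="auto",
        server_base_url=None,
        server_host="0.0.0.0",
        server_timeout=240.0,
        tensor_parallel_size=1,
    ):
        ...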
Thanks, @qgallouedec.
I had something similar in mind in my original plan:
- PR 2: Refine the interface
  - I had in mind to create something like a VLLMGenerationConfig class with all present (and open for future) vLLM parameters.
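For illustration, such a config class could be a small dataclass grouping the vLLM parameters (a sketch only; VLLMGenerationConfig does not exist in TRL, and the fields and defaults shown are assumptions):

from dataclasses import dataclass
from typing import Optional


@dataclass
class VLLMGenerationConfig:
    # Sketch only: groups the vLLM-related parameters discussed above.
    mode: str = "server"
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    max_model_length: Optional[int] = None
    max_num_seqs: Optional[int] = None
    enable_sleep_mode: bool = False
    # ... remaining parameters (server_host, server_timeout, guided_decoding_regex, etc.)

# The generator would then take a single object, e.g.:
# generator = VLLMGeneration(model_id, accelerator, config=VLLMGenerationConfig(mode="colocate"))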
If you prefer, I can do all the refactoring in a single PR?
@qgallouedec I think I addressed your suggestion above.
Not necessarily for this PR, but since we are adding a generation submodule, we could have vllm_client in this submodule as well.
Oh, thanks a lot for your detailed review!!! 🤗
trl/generation/vllm_generation.py
Outdated
# Build LLM initialization kwargs
llm_kwargs = {
    "model": model.name_or_path,
    "tensor_parallel_size": self.tensor_parallel_size,
    "gpu_memory_utilization": self.gpu_memory_utilization,
    "max_num_seqs": self.max_num_seqs,
    "max_model_len": self.max_model_length,
    "distributed_executor_backend": "external_launcher",
    # Feed identical seed for tp groups to ensure sampling results are the same across workers
    "seed": accelerator.process_index // self.tensor_parallel_size,
    # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
    "max_num_batched_tokens": 4096,
    "model_impl": self.model_impl,
    "enable_sleep_mode": self.enable_sleep_mode,
    # Important so temperature scaling/logit tweaking affects the TIS log probs
    "logprobs_mode": "processed_logprobs",
    "quantization": quantization,
}
self.llm = LLM(**llm_kwargs)
nit: instead of

llm_kwargs = {
    "model": model.name_or_path,
    ...
}
self.llm = LLM(**llm_kwargs)

we can also simply do

self.llm = LLM(
    model=model.name_or_path,
    ...,
)
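Applied to the snippet above, that nit would give roughly the following (same arguments as the llm_kwargs dict, just passed directly to LLM):

self.llm = LLM(
    model=model.name_or_path,
    tensor_parallel_size=self.tensor_parallel_size,
    gpu_memory_utilization=self.gpu_memory_utilization,
    max_num_seqs=self.max_num_seqs,
    max_model_len=self.max_model_length,
    distributed_executor_backend="external_launcher",
    # Identical seed within each TP group so sampling matches across workers
    seed=accelerator.process_index // self.tensor_parallel_size,
    # Keep max_num_batched_tokens low so the vLLM v1 memory profiler isn't misled
    max_num_batched_tokens=4096,
    model_impl=self.model_impl,
    enable_sleep_mode=self.enable_sleep_mode,
    # Important so temperature scaling/logit tweaking affects the TIS log probs
    logprobs_mode="processed_logprobs",
    quantization=quantization,
)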
qgallouedec
left a comment
This looks good to me! This will simplify the trainers a lot.
Refactor vLLM generation [1/N]: Extract vLLM generation.
This PR introduces a new initialization module for vLLM generation in TRL trainers. The main change is the addition of conditional support for the VLLMGeneration backend, which will only be imported and exposed if the vllm dependency is available.
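A minimal sketch of that conditional export, assuming TRL's is_vllm_available() helper and a trl/generation/__init__.py module (the exact layout in the PR may differ):

# trl/generation/__init__.py (sketch)
from ..import_utils import is_vllm_available

__all__ = []

if is_vllm_available():
    from .vllm_generation import VLLMGeneration

    __all__.append("VLLMGeneration")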