💬 Add chat to vLLM client and server, update trainer calls #4450

qgallouedec merged 7 commits into main
Conversation
```python
        ]
        return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}


class ChatRequest(BaseModel):
```
Exactly the same as `generate`, except:
- images are within the messages (so we drop the `images` argument)
- a `chat_template_kwargs` argument is added
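For illustration only, here is a minimal sketch of what such a request model could look like. Only `messages` and `chat_template_kwargs` are taken from the discussion above; the sampling fields and their defaults are assumptions mirroring a typical generate endpoint, not the PR's actual schema:

```python
from typing import Any, Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    # Chat-format conversations; images, if any, travel inside the message
    # contents (e.g. OpenAI-style content parts), so there is no separate
    # `images` field as the generate request has.
    messages: list[list[dict[str, Any]]]
    # Sampling parameters mirroring the generate endpoint (names and defaults
    # here are assumptions).
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 16
    # Extra kwargs forwarded to the tokenizer's `apply_chat_template`,
    # e.g. {"add_generation_prompt": True}.
    chat_template_kwargs: Optional[dict[str, Any]] = None
```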
```diff
-        # FIXME: this endpoint doesn't exist in vllm_client
         output = self.vllm_client.chat(
-            prompts=ordered_set_of_prompts,
+            messages=ordered_set_of_prompts,
```
I use "messages" instead of "prompt" to align with vLLM
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
albertvillanova left a comment
Thanks: solid implementation! Just some comments and minor suggestions below.
lewtun left a comment
LGTM with a suggestion to double-check models like Llama are not getting a double BOS token.
```python
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat(self):
```
It would be good to check that the issues with double BOS tokens getting inserted have been fully resolved (e.g. for a Llama model): vllm-project/vllm#9519
@edbeeching ran into this during https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute and it has a subtle but negative impact on the generations.
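One way to check this locally (a sketch, not part of the PR; the model name is just an example of a BOS-prepending checkpoint):

```python
from transformers import AutoTokenizer

# Example Llama-family checkpoint whose chat template prepends BOS itself.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# Render the template to text, then tokenize WITHOUT special tokens: the
# template already inserted BOS, so tokenization must not add a second one
# (the failure mode from vllm-project/vllm#9519).
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
ids = tokenizer(text, add_special_tokens=False)["input_ids"]

bos = tokenizer.bos_token_id
assert ids[:2] != [bos, bos], "double BOS detected"
```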
Co-authored-by: Albert Villanova del Moral <[email protected]>
```python
        for seq in completion_ids:
            assert all(isinstance(tok, int) for tok in seq)

    def test_chat(self):
```
@qgallouedec I'm sorry, but I'm not able to run this test.
Could you please give me a hint about the environment requirements so I can run it?
Thanks! 🤗
We might have to mock the response?
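A rough sketch of what that mocking could look like. It assumes the client talks to the server over `requests`; which call sites to patch, and whether `__init__` needs more patching for its health check or communicator setup, would need checking against the actual `VLLMClient`:

```python
from unittest.mock import MagicMock, patch

from trl.extras.vllm_client import VLLMClient

# Canned payload shaped like the chat endpoint's response (see the diff above).
fake_response = MagicMock()
fake_response.status_code = 200
fake_response.json.return_value = {
    "prompt_ids": [[1, 2, 3]],
    "completion_ids": [[4, 5, 6]],
    "logprobs": [[-0.1, -0.2, -0.3]],
}

# Patch the HTTP layer so neither the constructor's server check nor chat()
# needs a live server. The exact call sites patched here are assumptions.
with patch("requests.get", return_value=fake_response), \
     patch("requests.Session.post", return_value=fake_response):
    client = VLLMClient()
    output = client.chat(messages=[[{"role": "user", "content": "Hi"}]])

for seq in output["completion_ids"]:
    assert all(isinstance(tok, int) for tok in seq)
```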
Slow tests pass locally, and GRPO training works as well.