Commit c4f2bc1

server: fix OpenAI API compatibility for usage statistics in chat streams
Parent commit: a094f38

2 files changed (+58, -42)

tools/server/server.cpp

Lines changed: 11 additions & 0 deletions
@@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
             {"model", oaicompat_model},
             {"system_fingerprint", build_info},
             {"object", "chat.completion.chunk"},
+        });
+
+        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+        deltas.push_back({
+            {"choices", json::array()},
+            {"created", t},
+            {"id", oaicompat_cmpl_id},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "chat.completion.chunk"},
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens", n_prompt_tokens},

tools/server/tests/unit/test_chat_completion.py

Lines changed: 47 additions & 42 deletions
@@ -72,27 +72,29 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content
     content = ""
     last_cmpl_id = None
     for i, data in enumerate(res):
-        choice = data["choices"][0]
-        if i == 0:
-            # Check first role message for stream=True
-            assert choice["delta"]["content"] is None
-            assert choice["delta"]["role"] == "assistant"
+        if data["choices"]:
+            choice = data["choices"][0]
+            if i == 0:
+                # Check first role message for stream=True
+                assert choice["delta"]["content"] is None
+                assert choice["delta"]["role"] == "assistant"
+            else:
+                assert "role" not in choice["delta"]
+                assert data["system_fingerprint"].startswith("b")
+                assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+                if last_cmpl_id is None:
+                    last_cmpl_id = data["id"]
+                assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
+            if choice["finish_reason"] in ["stop", "length"]:
+                assert "content" not in choice["delta"]
+                assert match_regex(re_content, content)
+                assert choice["finish_reason"] == finish_reason
+            else:
+                assert choice["finish_reason"] is None
+                content += choice["delta"]["content"] or ''
         else:
-            assert "role" not in choice["delta"]
-            assert data["system_fingerprint"].startswith("b")
-            assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
-            if last_cmpl_id is None:
-                last_cmpl_id = data["id"]
-            assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
-        if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
-            assert "content" not in choice["delta"]
-            assert match_regex(re_content, content)
-            assert choice["finish_reason"] == finish_reason
-        else:
-            assert choice["finish_reason"] is None
-            content += choice["delta"]["content"] or ''
 
 
 def test_chat_completion_with_openai_library():
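
The reworked test mirrors what any OpenAI-compatible consumer must now do: index `choices[0]` only when the array is non-empty, and read `usage` off the trailing chunk. A minimal client sketch using the `openai` Python package; the base URL, API key, and model name are placeholders for a locally running llama.cpp server, and `stream_options` is the spec's opt-in for usage (the guard is safe whether or not the server sends the trailing chunk unconditionally):

from openai import OpenAI

# Placeholder endpoint and credentials for a local llama.cpp server.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},
)

text = ""
usage = None
for chunk in stream:
    if chunk.choices:  # role, content, and finish_reason chunks
        text += chunk.choices[0].delta.content or ""
    else:              # trailing chunk: empty choices, usage attached
        usage = chunk.usage

print(text)
if usage is not None:
    print(usage.prompt_tokens, usage.completion_tokens)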
@@ -278,12 +280,14 @@ def test_chat_completion_with_timings_per_token():
             assert data["choices"][0]["delta"]["role"] == "assistant"
             assert "timings" not in data, f'First event should not have timings: {data}'
         else:
-            assert "role" not in data["choices"][0]["delta"]
-            assert "timings" in data
-            assert "prompt_per_second" in data["timings"]
-            assert "predicted_per_second" in data["timings"]
-            assert "predicted_n" in data["timings"]
-            assert data["timings"]["predicted_n"] <= 10
+            if data["choices"]:
+                assert "role" not in data["choices"][0]["delta"]
+            else:
+                assert "timings" in data
+                assert "prompt_per_second" in data["timings"]
+                assert "predicted_per_second" in data["timings"]
+                assert "predicted_n" in data["timings"]
+                assert data["timings"]["predicted_n"] <= 10
 
 
 def test_logprobs():
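
`timings` is a llama.cpp extension rather than part of the OpenAI chunk schema. Since the final result's fields now land on the trailing usage chunk, that is where the test looks for them. A small reader sketch over already-decoded chunk dicts; the `events` iterable and the helper name are hypothetical:

# `events` is a hypothetical iterable of decoded chat.completion.chunk dicts.
def final_timings(events):
    timings = None
    for event in events:
        # The empty-choices chunk is expected to carry the final timings;
        # use .get() so chunks without the extension don't raise.
        if not event.get("choices") and "timings" in event:
            timings = event["timings"]
    return timings  # e.g. keys: predicted_n, predicted_per_second, prompt_per_second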
@@ -332,24 +336,25 @@ def test_logprobs_stream():
     output_text = ''
     aggregated_text = ''
     for i, data in enumerate(res):
-        choice = data.choices[0]
-        if i == 0:
-            # Check first role message for stream=True
-            assert choice.delta.content is None
-            assert choice.delta.role == "assistant"
-        else:
-            assert choice.delta.role is None
-        if choice.finish_reason is None:
-            if choice.delta.content:
-                output_text += choice.delta.content
-            assert choice.logprobs is not None
-            assert choice.logprobs.content is not None
-            for token in choice.logprobs.content:
-                aggregated_text += token.token
-                assert token.logprob <= 0.0
-                assert token.bytes is not None
-                assert token.top_logprobs is not None
-                assert len(token.top_logprobs) > 0
+        if data.choices:
+            choice = data.choices[0]
+            if i == 0:
+                # Check first role message for stream=True
+                assert choice.delta.content is None
+                assert choice.delta.role == "assistant"
+            else:
+                assert choice.delta.role is None
+            if choice.finish_reason is None:
+                if choice.delta.content:
+                    output_text += choice.delta.content
+                assert choice.logprobs is not None
+                assert choice.logprobs.content is not None
+                for token in choice.logprobs.content:
+                    aggregated_text += token.token
+                    assert token.logprob <= 0.0
+                    assert token.bytes is not None
+                    assert token.top_logprobs is not None
+                    assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
 
 
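
The logprobs test applies the same guard before aggregating. The invariant it protects (concatenating `token.token` across all logprob entries reproduces the streamed text) can be checked from a plain client as well; a sketch with the `openai` package, placeholders as before:

from openai import OpenAI

# Placeholder endpoint and model for a local llama.cpp server.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Pick a number."}],
    stream=True,
    logprobs=True,
    top_logprobs=3,
)

from_deltas = ""
from_logprobs = ""
for chunk in stream:
    if not chunk.choices:  # skip the trailing usage chunk
        continue
    choice = chunk.choices[0]
    if choice.delta.content:
        from_deltas += choice.delta.content
    if choice.logprobs and choice.logprobs.content:
        for token in choice.logprobs.content:
            from_logprobs += token.token

# Concatenated logprob tokens must reproduce the streamed content.
assert from_deltas == from_logprobs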
