
Commit bb65b4d

fix: pass correct type to chat handlers for chat completion logprobs

1 parent 060bfa6 · commit bb65b4d

File tree

2 files changed: +18 -9 lines

llama_cpp/llama.py (+2 -1)

@@ -1664,7 +1664,8 @@ def create_chat_completion(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=top_logprobs if logprobs else None,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             stream=stream,
             stop=stop,
             seed=seed,
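
With this change, create_chat_completion forwards the OpenAI-style boolean logprobs flag and the integer top_logprobs count to the chat handler separately, instead of collapsing them into a single value up front. A minimal caller-side sketch of the fixed behaviour, assuming llama-cpp-python's high-level API and a hypothetical local GGUF model path:

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # path is illustrative

    # logprobs is the OpenAI-style on/off switch; top_logprobs is the
    # number of alternatives per token. Both now reach the chat handler.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        logprobs=True,
        top_logprobs=3,
    )
    print(response["choices"][0]["logprobs"])
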

llama_cpp/llama_chat_format.py (+16 -8)

@@ -77,6 +77,8 @@ def __call__(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,

@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
                     }
                 ],
             },
-            "logprobs": None,
+            "logprobs": completion["choices"][0]["logprobs"],
             "finish_reason": "tool_calls",
         }
     ],
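
The chat handler protocol's __call__ gains the same two keyword arguments, so registered handlers receive them explicitly rather than dropping them into **kwargs. A hedged sketch of a custom handler conforming to the widened signature; everything except the two new parameters is a placeholder:

    from typing import Optional

    def my_chat_handler(
        *,
        llama,                        # llama.Llama instance
        messages,                     # OpenAI-style message dicts
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # remaining sampling parameters (ignored here)
    ):
        # Translate the chat-level flag pair to the completion-level int,
        # mirroring what the built-in handler now does.
        return llama.create_completion(
            prompt="...",  # prompt built from messages (elided)
            logprobs=top_logprobs if logprobs else None,
        )
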
@@ -391,7 +393,7 @@ def _stream_response_to_function_stream(
             {
                 "index": 0,
                 "finish_reason": None,
-                "logprobs": None,
+                "logprobs": chunk["choices"][0]["logprobs"],
                 "delta": {
                     "role": None,
                     "content": None,

@@ -426,7 +428,7 @@ def _stream_response_to_function_stream(
             {
                 "index": 0,
                 "finish_reason": None,
-                "logprobs": None,
+                "logprobs": chunk["choices"][0]["logprobs"],
                 "delta": {
                     "role": None,
                     "content": None,
@@ -491,7 +493,6 @@ def chat_completion_handler(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
-        logprobs: int = 0,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,

@@ -512,6 +513,8 @@ def chat_completion_handler(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,

@@ -581,7 +584,7 @@ def chat_completion_handler(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=logprobs,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
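
The handler no longer takes a completion-style logprobs: int; it accepts the chat-style pair and converts at the create_completion boundary with top_logprobs if logprobs else None. The conversion semantics in isolation, as a sketch with an illustrative helper name:

    from typing import Optional

    def to_completion_logprobs(
        logprobs: Optional[bool], top_logprobs: Optional[int]
    ) -> Optional[int]:
        # Request top_logprobs alternatives per token only when the
        # chat-style flag is truthy; otherwise disable logprobs entirely.
        return top_logprobs if logprobs else None

    assert to_completion_logprobs(True, 3) == 3
    assert to_completion_logprobs(False, 3) is None
    assert to_completion_logprobs(None, 3) is None
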
@@ -1628,7 +1631,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                     }
                 ],
             },
-            "logprobs": None,
+            "logprobs": completion["choices"][0]["logprobs"],
             "finish_reason": "tool_calls",
         }
     ],

@@ -2085,7 +2088,7 @@ def create_completion(stop):
         choices=[
             {
                 "index": 0,
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "message": {
                     "role": "assistant",
                     "content": None if content == "" else content,

@@ -2311,11 +2314,14 @@ def chatml_function_calling(
     model: Optional[str] = None,
     logits_processor: Optional[llama.LogitsProcessorList] = None,
     grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     **kwargs,  # type: ignore
 ) -> Union[
     llama_types.CreateChatCompletionResponse,
     Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+    print(logprobs)
     function_calling_template = (
         "{% for message in messages %}"
         "<|im_start|>{{ message.role }}\n"

@@ -2437,6 +2443,7 @@ def chatml_function_calling(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logprobs=top_logprobs if logprobs else None,
         ),
         stream=stream,
     )

@@ -2549,6 +2556,7 @@ def chatml_function_calling(
         typical_p=typical_p,
         stream=stream,
         stop=["<|im_end|>"],
+        logprobs=top_logprobs if logprobs else None,
         max_tokens=None,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,

@@ -2660,7 +2668,7 @@ def chatml_function_calling(
         {
             "finish_reason": "tool_calls",
             "index": 0,
-            "logprobs": None,
+            "logprobs": completion["choices"][0]["logprobs"],
             "message": {
                 "role": "assistant",
                 "content": None,
