Commit 07e47f5

Add support for logit_bias outside of server api. Closes abetlen#827

1 parent: c21edb6

3 files changed: +44 −38

llama_cpp/llama.py (+25)

@@ -1327,6 +1327,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1355,6 +1356,28 @@ def _create_completion(
         )
         model_name: str = model if model is not None else self.model_path

+        # NOTE: This likely doesn't work correctly for the first token in the prompt
+        # because of the extra space added to the start of the prompt_tokens
+        if logit_bias is not None:
+            logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
+
+            def logit_bias_processor(
+                input_ids: npt.NDArray[np.intc],
+                scores: npt.NDArray[np.single],
+            ) -> npt.NDArray[np.single]:
+                new_scores = np.copy(
+                    scores
+                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
+                for input_id, score in logit_bias_map.items():
+                    new_scores[input_id] = score + scores[input_id]
+                return new_scores
+
+            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
+            if logits_processor is None:
+                logits_processor = _logit_bias_processor
+            else:
+                logits_processor = logits_processor.extend(_logit_bias_processor)
+
         if self.verbose:
             self._ctx.reset_timings()

@@ -1963,6 +1986,7 @@ def create_chat_completion(
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2011,6 +2035,7 @@ def create_chat_completion(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )

     def __getstate__(self):
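
With this change, logit biasing can be applied directly through the Python API rather than only via the HTTP server. A minimal usage sketch, not part of the commit itself; the model path and token id below are placeholders:

import llama_cpp

llama = llama_cpp.Llama(model_path="./models/model.gguf")  # placeholder path

response = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=32,
    # Keys identify token ids; each value is added to that token's logit at every
    # sampling step by the logit_bias_processor introduced above.
    logit_bias={"15043": -100.0},  # hypothetical token id to discourage
)
print(response["choices"][0]["message"]["content"])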

llama_cpp/llama_chat_format.py (+3)

@@ -45,6 +45,7 @@ def __call__(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -308,6 +309,7 @@ def basic_create_chat_completion(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -350,6 +352,7 @@ def basic_create_chat_completion(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
        )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
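
Note that the chat-level signatures annotate logit_bias as Dict[str, float] while _create_completion declares Dict[int, float]; the {int(k): float(v) ...} comprehension in _create_completion normalizes either form. A small illustration (the token id and value are made up):

# Both key shapes normalize to the same mapping inside _create_completion.
bias_str_keys = {"15043": -100.0}  # string keys, as in the chat/server signatures
bias_int_keys = {15043: -100.0}    # int keys, as in _create_completion's annotation

assert {int(k): float(v) for k, v in bias_str_keys.items()} == \
       {int(k): float(v) for k, v in bias_int_keys.items()}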

llama_cpp/server/app.py (+16 −38)

@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
     }


-def make_logit_bias_processor(
+def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
     logit_bias: Dict[str, float],
-    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
-):
-    if logit_bias_type is None:
-        logit_bias_type = "input_ids"
-
-    to_bias: Dict[int, float] = {}
-    if logit_bias_type == "input_ids":
-        for input_id, score in logit_bias.items():
-            input_id = int(input_id)
-            to_bias[input_id] = score
-
-    elif logit_bias_type == "tokens":
-        for token, score in logit_bias.items():
-            token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False, special=True):
-                to_bias[input_id] = score
-
-    def logit_bias_processor(
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-    ) -> npt.NDArray[np.single]:
-        new_scores = np.copy(scores)  # Does it make sense to copy the whole array or can we just overwrite the original one?
-        for input_id, score in to_bias.items():
-            new_scores[input_id] = score + scores[input_id]
-        return new_scores
-
-    return logit_bias_processor
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
+    for token, score in logit_bias.items():
+        token = token.encode("utf-8")
+        for input_id in llama.tokenize(token, add_bos=False, special=True):
+            to_bias[str(input_id)] = score
+    return to_bias


 @router.post(
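
For illustration, a sketch of how the new helper behaves when a client biases by token text (not from the commit; it assumes the helper above is importable from llama_cpp.server.app, and the model path and resulting token id are placeholders that depend on the loaded vocabulary):

import llama_cpp
from llama_cpp.server.app import _logit_bias_tokens_to_input_ids  # assumed import path

llama = llama_cpp.Llama(model_path="./models/model.gguf")  # placeholder path
token_bias = {"Hello": 5.0}  # as sent by a client with logit_bias_type="tokens"

id_bias = _logit_bias_tokens_to_input_ids(llama, token_bias)
# If "Hello" tokenizes to the single id 15043, id_bias == {"15043": 5.0},
# i.e. the Dict[str, float] shape forwarded as logit_bias below.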
@@ -694,17 +674,16 @@ async def create_completion(
     exclude = {
         "n",
         "best_of",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
@@ -851,17 +830,16 @@ async def create_chat_completion(
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
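
From a client's perspective, both request shapes hit the same endpoints. A hedged sketch using the OpenAI-compatible completion route (host, port, token id and bias values are placeholders):

import requests

base = "http://localhost:8000"  # placeholder host/port for the llama_cpp server

# Default behaviour (logit_bias_type omitted or "input_ids"): keys are token ids.
requests.post(f"{base}/v1/completions", json={
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "logit_bias": {"15043": -100.0},  # hypothetical token id
})

# Token-text form: the server converts each key to token ids before generation.
requests.post(f"{base}/v1/completions", json={
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "logit_bias": {"Hello": -100.0},
    "logit_bias_type": "tokens",
})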
