Commit eb7645b

Add support for logit_bias and logit_bias_type parameters
1 parent 0da655b commit eb7645b

2 files changed, 53 insertions(+), 2 deletions(-)

llama_cpp/llama.py (+2)
@@ -1380,6 +1380,7 @@ def create_chat_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1421,6 +1422,7 @@ def create_chat_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
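With the new parameter in place, a caller can hand create_chat_completion a processor list directly. A minimal sketch, assuming a model is already downloaded; the model path and token id below are illustrative, not part of this commit:

import llama_cpp

llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")  # illustrative path

def boost_token(input_ids, scores):
    # Nudge one token id upward at every sampling step (the id is made up).
    scores[15043] += 5.0
    return scores

completion = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello"}],
    logits_processor=llama_cpp.LogitsProcessorList([boost_token]),
)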

llama_cpp/server/app.py (+51 -2)
@@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
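The previously ignored logit_bias field moves out of the unsupported block and gains a companion field selecting how its keys are interpreted. Two illustrative request fragments (the ids and strings are made up):

With logit_bias_type "input_ids" (also the default when the field is omitted), keys are token ids serialized as strings:

{"logit_bias": {"15043": 5.0, "15044": -5.0}, "logit_bias_type": "input_ids"}

With "tokens", keys are raw text that the server tokenizes itself:

{"logit_bias": {"Hello": 5.0}, "logit_bias_type": "tokens"}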
@@ -274,6 +275,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
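The helper resolves logit_bias into a token-id-to-bias table once, then returns a closure that adds the matching bias to each logit by position. A quick behavioral sketch (the llama argument is untouched in the "input_ids" branch, so any loaded model works; the values are illustrative):

processor = make_logit_bias_processor(llama, {"2": 1.5}, "input_ids")
scores = [0.0, 0.0, 0.0, 0.0]   # logits indexed by token id
processor([], scores)           # -> [0.0, 0.0, 1.5, 0.0]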
@@ -291,9 +325,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
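End to end, an OpenAI-style completion request can now carry the bias. A hedged example against a locally running server; the host, port, prompt, and token id are assumptions, not part of the commit:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        # Omitting logit_bias_type defaults to "input_ids".
        "logit_bias": {"15043": 10.0},  # token id (string key) -> bias
    },
)
print(resp.json()["choices"][0]["text"])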

@@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters

@@ -413,9 +455,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
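The chat endpoint gets the same wiring, including the "tokens" mode, where the server tokenizes the keys itself. Another hedged sketch (host, port, and values assumed):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Greet me"}],
        "logit_bias": {"Hello": 5.0},  # raw text keys, tokenized server-side
        "logit_bias_type": "tokens",
    },
)
print(resp.json()["choices"][0]["message"]["content"])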
