Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f568bae

Browse files
authored
Merge pull request abetlen#351 from player1537-forks/th/add-logits-bias-parameter
Add support for `logit_bias` and `logit_bias_type` parameters
2 parents abf6d4a + eb7645b commit f568bae

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

llama_cpp/llama.py

+2
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,7 @@ def create_chat_completion(
13781378
mirostat_tau: float = 5.0,
13791379
mirostat_eta: float = 0.1,
13801380
model: Optional[str] = None,
1381+
logits_processor: Optional[LogitsProcessorList] = None,
13811382
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
13821383
"""Generate a chat completion from a list of messages.
13831384
@@ -1419,6 +1420,7 @@ def create_chat_completion(
14191420
mirostat_tau=mirostat_tau,
14201421
mirostat_eta=mirostat_eta,
14211422
model=model,
1423+
logits_processor=logits_processor,
14221424
)
14231425
if stream:
14241426
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

llama_cpp/server/app.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
255255
)
256256
presence_penalty: Optional[float] = presence_penalty_field
257257
frequency_penalty: Optional[float] = frequency_penalty_field
258+
logit_bias: Optional[Dict[str, float]] = Field(None)
259+
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
258260

259261
# ignored or currently unsupported
260262
model: Optional[str] = model_field
261263
n: Optional[int] = 1
262264
logprobs: Optional[int] = Field(None)
263265
best_of: Optional[int] = 1
264-
logit_bias: Optional[Dict[str, float]] = Field(None)
265266
user: Optional[str] = Field(None)
266267

267268
# llama.cpp specific parameters
@@ -280,6 +281,39 @@ class Config:
280281
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
281282

282283

284+
def make_logit_bias_processor(
285+
llama: llama_cpp.Llama,
286+
logit_bias: Dict[str, float],
287+
logit_bias_type: Optional[Literal["input_ids", "tokens"]],
288+
):
289+
if logit_bias_type is None:
290+
logit_bias_type = "input_ids"
291+
292+
to_bias: Dict[int, float] = {}
293+
if logit_bias_type == "input_ids":
294+
for input_id, score in logit_bias.items():
295+
input_id = int(input_id)
296+
to_bias[input_id] = score
297+
298+
elif logit_bias_type == "tokens":
299+
for token, score in logit_bias.items():
300+
token = token.encode('utf-8')
301+
for input_id in llama.tokenize(token, add_bos=False):
302+
to_bias[input_id] = score
303+
304+
def logit_bias_processor(
305+
input_ids: List[int],
306+
scores: List[float],
307+
) -> List[float]:
308+
new_scores = [None] * len(scores)
309+
for input_id, score in enumerate(scores):
310+
new_scores[input_id] = score + to_bias.get(input_id, 0.0)
311+
312+
return new_scores
313+
314+
return logit_bias_processor
315+
316+
283317
@router.post(
284318
"/v1/completions",
285319
response_model=CreateCompletionResponse,
@@ -297,9 +331,16 @@ async def create_completion(
297331
"n",
298332
"best_of",
299333
"logit_bias",
334+
"logit_bias_type",
300335
"user",
301336
}
302337
kwargs = body.dict(exclude=exclude)
338+
339+
if body.logit_bias is not None:
340+
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
341+
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
342+
])
343+
303344
if body.stream:
304345
send_chan, recv_chan = anyio.create_memory_object_stream(10)
305346

@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
378419
stream: bool = stream_field
379420
presence_penalty: Optional[float] = presence_penalty_field
380421
frequency_penalty: Optional[float] = frequency_penalty_field
422+
logit_bias: Optional[Dict[str, float]] = Field(None)
423+
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
381424

382425
# ignored or currently unsupported
383426
model: Optional[str] = model_field
384427
n: Optional[int] = 1
385-
logit_bias: Optional[Dict[str, float]] = Field(None)
386428
user: Optional[str] = Field(None)
387429

388430
# llama.cpp specific parameters
@@ -419,9 +461,16 @@ async def create_chat_completion(
419461
exclude = {
420462
"n",
421463
"logit_bias",
464+
"logit_bias_type",
422465
"user",
423466
}
424467
kwargs = body.dict(exclude=exclude)
468+
469+
if body.logit_bias is not None:
470+
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
471+
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
472+
])
473+
425474
if body.stream:
426475
send_chan, recv_chan = anyio.create_memory_object_stream(10)
427476

0 commit comments

Comments
 (0)