@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
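The two new request fields accept either raw token ids ("input_ids") or surface strings ("tokens"). A hedged client sketch of a request using the new fields; the server URL and the token ids are illustrative assumptions, not part of this patch:

```python
# Illustrative client call exercising the new fields. The URL and the
# token ids are assumptions for the sake of the example.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 16,
        # keys are token ids as strings; an additive -100 bias
        # effectively bans a token
        "logit_bias": {"15043": -100.0, "3834": 5.0},
        "logit_bias_type": "input_ids",
    },
)
print(resp.json()["choices"][0]["text"])
```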
@@ -280,6 +281,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
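The returned closure captures `to_bias`, a map from token id to additive bias. Note that `scores` is indexed by token id (the loop variable named `input_id` is really a vocabulary index), so biasing a token just shifts one entry. A toy, llama.cpp-free sketch of the same arithmetic:

```python
# Toy rendition of the processor's arithmetic (no llama.cpp involved):
# scores is indexed by token id, so to_bias shifts individual entries.
to_bias = {2: -100.0, 4: 1.0}

def toy_logit_bias_processor(input_ids, scores):
    return [s + to_bias.get(i, 0.0) for i, s in enumerate(scores)]

print(toy_logit_bias_processor([], [0.1, 0.2, 0.3, 0.4, 0.5]))
# -> [0.1, 0.2, -99.7, 0.4, 1.5]
```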
@@ -297,9 +331,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

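The handler builds a `LogitsProcessorList` only when `logit_bias` is present, so unbiased requests pay no extra cost. The same plumbing can be driven directly, outside the server; a sketch under the assumption of a local model file (the path is hypothetical):

```python
# Driving the processor directly, without the server. The model path is
# hypothetical; make_logit_bias_processor is the helper added above.
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/ggml-model.bin")
out = llama.create_completion(
    "The capital of France is",
    max_tokens=8,
    logits_processor=llama_cpp.LogitsProcessorList(
        [make_logit_bias_processor(llama, {"1234": -100.0}, "input_ids")]
    ),
)
print(out["choices"][0]["text"])
```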
@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
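The chat request model gains the same pair of fields. In "tokens" mode the keys are strings that the server tokenizes before biasing; a hedged sketch (URL and biased strings are illustrative):

```python
# Illustrative chat request using "tokens" mode; the URL and the biased
# strings are assumptions for the sake of the example.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Answer yes or no: is 7 prime?"}],
        "logit_bias": {" yes": 10.0, " no": -10.0},
        "logit_bias_type": "tokens",
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```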
@@ -419,9 +461,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

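One caveat for "tokens" mode: a multi-token string spreads its single score over every id it tokenizes to, since the inner loop in `make_logit_bias_processor` assigns the same score to each id. A quick way to see how a string splits before biasing it (model path hypothetical, ids depend on the vocabulary):

```python
# Inspecting how a string tokenizes, to predict what "tokens" mode will
# bias. The model path is hypothetical.
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/ggml-model.bin")
print(llama.tokenize(" Paris".encode("utf-8"), add_bos=False))
# e.g. [3681]  (each listed id would receive the full bias score)
```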