@@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel):
249
249
)
250
250
presence_penalty : Optional [float ] = presence_penalty_field
251
251
frequency_penalty : Optional [float ] = frequency_penalty_field
252
+ logit_bias : Optional [Dict [str , float ]] = Field (None )
253
+ logit_bias_type : Optional [Literal ["input_ids" , "tokens" ]] = Field (None )
252
254
253
255
# ignored or currently unsupported
254
256
model : Optional [str ] = model_field
255
257
n : Optional [int ] = 1
256
258
logprobs : Optional [int ] = Field (None )
257
259
best_of : Optional [int ] = 1
258
- logit_bias : Optional [Dict [str , float ]] = Field (None )
259
260
user : Optional [str ] = Field (None )
260
261
261
262
# llama.cpp specific parameters
@@ -274,6 +275,39 @@ class Config:
274
275
CreateCompletionResponse = create_model_from_typeddict (llama_cpp .Completion )
275
276
276
277
278
def make_logit_bias_processor(
    llama: "llama_cpp.Llama",
    logit_bias: Dict[str, float],
    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
    """Build a logits processor that applies per-token score biases.

    Args:
        llama: Model handle; only used (for tokenization) when
            ``logit_bias_type == "tokens"``.
        logit_bias: Mapping of token identifier -> bias to add to that
            token's logit. Keys are strings because they arrive as JSON
            object keys.
        logit_bias_type: How to interpret the keys of ``logit_bias``:
            ``"input_ids"`` (default when None) treats each key as a
            vocabulary token id; ``"tokens"`` treats each key as literal
            token text, which is tokenized and every resulting id biased.

    Returns:
        A callable ``(input_ids, scores) -> scores`` that returns a new
        score list with the configured biases added.
    """
    if logit_bias_type is None:
        logit_bias_type = "input_ids"

    # Resolve the string-keyed request mapping into {token_id: bias} once,
    # up front, so the per-step processor is a cheap dict lookup.
    to_bias: Dict[int, float] = {}
    if logit_bias_type == "input_ids":
        # Keys are vocabulary token ids serialized as strings.
        for key, score in logit_bias.items():
            to_bias[int(key)] = score
    elif logit_bias_type == "tokens":
        # Keys are literal token text; bias every id it tokenizes to.
        # NOTE: later entries overwrite earlier ones on id collisions.
        for token, score in logit_bias.items():
            for token_id in llama.tokenize(token.encode('utf-8'), add_bos=False):
                to_bias[token_id] = score

    def logit_bias_processor(
        input_ids: List[int],
        scores: List[float],
    ) -> List[float]:
        # `scores` is indexed by vocabulary token id, so enumerate() yields
        # (token_id, logit) pairs; add the bias where one is configured.
        return [
            score + to_bias.get(token_id, 0.0)
            for token_id, score in enumerate(scores)
        ]

    return logit_bias_processor
309
+
310
+
277
311
@router .post (
278
312
"/v1/completions" ,
279
313
response_model = CreateCompletionResponse ,
@@ -291,9 +325,16 @@ async def create_completion(
291
325
"n" ,
292
326
"best_of" ,
293
327
"logit_bias" ,
328
+ "logit_bias_type" ,
294
329
"user" ,
295
330
}
296
331
kwargs = body .dict (exclude = exclude )
332
+
333
+ if body .logit_bias is not None :
334
+ kwargs ['logits_processor' ] = llama_cpp .LogitsProcessorList ([
335
+ make_logit_bias_processor (llama , body .logit_bias , body .logit_bias_type ),
336
+ ])
337
+
297
338
if body .stream :
298
339
send_chan , recv_chan = anyio .create_memory_object_stream (10 )
299
340
@@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel):
372
413
stream : bool = stream_field
373
414
presence_penalty : Optional [float ] = presence_penalty_field
374
415
frequency_penalty : Optional [float ] = frequency_penalty_field
416
+ logit_bias : Optional [Dict [str , float ]] = Field (None )
417
+ logit_bias_type : Optional [Literal ["input_ids" , "tokens" ]] = Field (None )
375
418
376
419
# ignored or currently unsupported
377
420
model : Optional [str ] = model_field
378
421
n : Optional [int ] = 1
379
- logit_bias : Optional [Dict [str , float ]] = Field (None )
380
422
user : Optional [str ] = Field (None )
381
423
382
424
# llama.cpp specific parameters
@@ -413,9 +455,16 @@ async def create_chat_completion(
413
455
exclude = {
414
456
"n" ,
415
457
"logit_bias" ,
458
+ "logit_bias_type" ,
416
459
"user" ,
417
460
}
418
461
kwargs = body .dict (exclude = exclude )
462
+
463
+ if body .logit_bias is not None :
464
+ kwargs ['logits_processor' ] = llama_cpp .LogitsProcessorList ([
465
+ make_logit_bias_processor (llama , body .logit_bias , body .logit_bias_type ),
466
+ ])
467
+
419
468
if body .stream :
420
469
send_chan , recv_chan = anyio .create_memory_object_stream (10 )
421
470
0 commit comments