@@ -2232,6 +2232,7 @@ def __call__(
         typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         response_format: Optional[
             llama_types.ChatCompletionRequestResponseFormat
         ] = None,
@@ -2246,6 +2247,9 @@ def __call__(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -2309,32 +2313,77 @@ def free_embed():
         if response_format is not None and response_format["type"] == "json_object":
             grammar = _grammar_for_response_format(response_format)
 
-        # TODO: Add function call support
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
 
-        return _convert_completion_to_chat(
-            llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=stop,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
-            ),
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
         )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
 
     @staticmethod
     def _load_image(image_url: str) -> bytes:
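
For context, the legacy-to-tools conversion this diff introduces can be exercised in isolation. A minimal sketch, assuming only the deprecated OpenAI `functions`/`function_call` input shapes; the `get_weather` schema is a hypothetical example:

```python
# Deprecated OpenAI-style inputs (hypothetical example schema).
functions = [
    {
        "name": "get_weather",
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]
function_call = {"name": "get_weather"}

# Same conversion as in the diff: wrap each legacy function as a tool.
tools = [{"type": "function", "function": f} for f in functions]

# Map function_call onto tool_choice ("none"/"auto" pass through as strings).
if isinstance(function_call, str) and function_call in ("none", "auto"):
    tool_choice = function_call
elif isinstance(function_call, dict) and "name" in function_call:
    tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

# Resolve the chosen tool the same way the handler does.
name = tool_choice["function"]["name"]
tool = next((t for t in tools if t["function"]["name"] == name), None)
assert tool is not None, f"Tool choice '{name}' not found in tools."
print(tool["function"]["parameters"])  # the schema that seeds the grammar
```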
0 commit comments
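The resolved tool's parameter schema is then compiled into a GBNF grammar so sampling can only produce matching JSON. A sketch of that step, assuming `llama_cpp` is installed and exposes `LlamaGrammar.from_json_schema`, `LlamaGrammar.from_string`, and the generic `JSON_GBNF` fallback used in the diff:

```python
import json

from llama_cpp import llama_grammar

# Hypothetical tool-parameter schema (same shape as the example above).
schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}

try:
    # Constrain generation to JSON that validates against the schema.
    grammar = llama_grammar.LlamaGrammar.from_json_schema(
        json.dumps(schema), verbose=False
    )
except Exception:
    # Fall back to a generic JSON grammar if the schema can't be compiled.
    grammar = llama_grammar.LlamaGrammar.from_string(
        llama_grammar.JSON_GBNF, verbose=False
    )
```

The grammar is passed through to `llama.create_completion`, and when a tool was selected the resulting completion is converted back into a chat `tool_calls` response by `_convert_completion_to_chat_function`.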