
Commit 94fe4bc

Add function calling support
1 parent fd55c29 commit 94fe4bc

1 file changed: +72 -23 lines changed

llama_cpp/llama_chat_format.py (+72 -23)
@@ -2232,6 +2232,7 @@ def __call__(
         typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         response_format: Optional[
             llama_types.ChatCompletionRequestResponseFormat
         ] = None,
@@ -2246,6 +2247,9 @@ def __call__(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -2309,32 +2313,77 @@ def free_embed():
         if response_format is not None and response_format["type"] == "json_object":
             grammar = _grammar_for_response_format(response_format)
 
-        # TODO: Add function call support
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
 
-        return _convert_completion_to_chat(
-            llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=stop,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
-            ),
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
+        )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
 
     @staticmethod
     def _load_image(image_url: str) -> bytes:
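
Below is a minimal usage sketch (not part of the commit) of the tool-calling path this change enables, invoked through llama-cpp-python's high-level Llama.create_chat_completion. The model path, the get_weather tool, and its schema are hypothetical placeholders.

# Minimal sketch of the new tool-calling path; not part of this commit.
# Model path, tool name, and schema are hypothetical placeholders.
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # hypothetical local GGUF model

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    # Forcing this tool routes the result through
    # _convert_completion_to_chat_function, with output constrained
    # by a grammar derived from the tool's parameter schema.
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(response["choices"][0]["message"])

Legacy callers that pass functions=[...] and function_call={"name": ...} instead are converted to this tools/tool_choice form by the new code in the diff.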

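The schema-to-grammar step in the last hunk can also be sketched standalone. The two LlamaGrammar calls mirror the diff; the schema below is a hypothetical example.

# Standalone sketch of the schema-to-grammar step used in the diff above.
# The schema is a hypothetical example; the API calls mirror the commit.
import json

from llama_cpp.llama_grammar import JSON_GBNF, LlamaGrammar

schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}

try:
    # Constrain generation to JSON matching the tool's parameter schema.
    grammar = LlamaGrammar.from_json_schema(json.dumps(schema), verbose=False)
except Exception:
    # Fall back to a grammar that accepts any well-formed JSON.
    grammar = LlamaGrammar.from_string(JSON_GBNF, verbose=False)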