@@ -1830,27 +1830,35 @@ def prepare_messages_for_inference(
     version: Literal["v1", "v2"],
     functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
     tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Union[Dict, str] = "auto",
 ):
     all_messages: List[llama_types.ChatCompletionRequestMessage] = []
-    if functions is not None:
+    if tool_choice == "none":
         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
-                role="system", content=generate_schema_from_functions(functions)
+                role="system", content=generate_schema_from_functions([])
             )
         )
-    elif tools is not None:
-        all_messages.append(
-            llama_types.ChatCompletionRequestSystemMessage(
-                role="system",
-                content=generate_schema_from_functions(
-                    [
-                        tool["function"]
-                        for tool in tools
-                        if tool["type"] == "function"
-                    ]
-                ),
+    else:
+        if functions is not None:
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system", content=generate_schema_from_functions(functions)
+                )
+            )
+        elif tools is not None and tool_choice != "none":
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system",
+                    content=generate_schema_from_functions(
+                        [
+                            tool["function"]
+                            for tool in tools
+                            if tool["type"] == "function"
+                        ]
+                    ),
+                )
             )
-        )
 
     all_messages.append(
         llama_types.ChatCompletionRequestSystemMessage(
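With `tool_choice` threaded into `prepare_messages_for_inference`, a request with `tool_choice == "none"` now renders an *empty* function schema into the system prompt instead of omitting the schema message entirely, which steers the model away from emitting function calls. A minimal sketch of the selection logic, with `generate_schema_from_functions` and the message types replaced by stand-ins (not the library's actual API):

```python
from typing import Any, Dict, List, Optional, Union

def select_schema_functions(
    tool_choice: Union[Dict, str] = "auto",
    functions: Optional[List[Dict[str, Any]]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[Dict[str, Any]]]:
    """Return the function list to render into the system prompt,
    or None when no schema message should be added at all."""
    if tool_choice == "none":
        return []  # empty schema: the model sees no callable functions
    if functions is not None:
        return functions  # legacy `functions` API takes precedence
    if tools is not None:
        return [t["function"] for t in tools if t["type"] == "function"]
    return None

# tool_choice="none" wins even when tools are supplied:
assert select_schema_functions(
    "none", tools=[{"type": "function", "function": {"name": "f"}}]
) == []
```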
@@ -1890,7 +1898,7 @@ def prepare_messages_for_inference(
         function_call = "auto"
 
     prompt = prepare_messages_for_inference(
-        messages, tokenizer, version, functions, tools
+        messages, tokenizer, version, functions, tools, function_call
     )
 
     # If no tools/functions are provided
@@ -1987,17 +1995,12 @@ def create_completion(stop):
 
     content = ""
     function_calls, function_bodies = [], []
+    completion_tokens = 0
 
     if version == "v1":
         # If no or "auto" tool_choice/function_call
         if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
-        # If tool_choice/function_call is "none"
-        elif isinstance(function_call, str) and function_call == "none":
-            prompt = prepare_messages_for_inference(
-                messages, tokenizer, version, [], []
-            )
-            stops = END_ASSISTANT_TOKEN
         # If tool_choice/function_call is provided
         elif isinstance(function_call, dict):
             prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
@@ -2011,12 +2014,15 @@ def create_completion(stop):
 
         completion = create_completion(stop=stops)
         completion_text = completion["choices"][0]["text"]
+        completion_tokens += completion["usage"]["completion_tokens"]
+
 
         # If the generation does not involve a function call
         if (
             START_FUNCTION_CALL_TOKEN not in prompt
             and START_FUNCTION_CALL_TOKEN not in completion_text
         ):
+            completion["usage"]["completion_tokens"] = completion_tokens
             return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If the generation involves a function call in completion, generate the parameters
         elif (
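A single chat turn can require several `create_completion` calls here (function-name generation, argument generation, follow-up turns), so the usage previously reported reflected only the last call. The new `completion_tokens` counter accumulates across calls and is written back into the final completion's `usage` before conversion. A small sketch of the pattern, assuming OpenAI-style completion dicts with a `usage.completion_tokens` field:

```python
from typing import Any, Dict, List

def total_completion_tokens(completions: List[Dict[str, Any]]) -> int:
    """Sum completion_tokens across every sub-completion in a turn."""
    return sum(c["usage"]["completion_tokens"] for c in completions)

calls = [
    {"usage": {"completion_tokens": 7}},   # e.g. function-name generation
    {"usage": {"completion_tokens": 42}},  # e.g. argument generation
]
final = calls[-1]
final["usage"]["completion_tokens"] = total_completion_tokens(calls)
assert final["usage"]["completion_tokens"] == 49
```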
@@ -2034,30 +2040,22 @@ def create_completion(stop):
             )
             grammar = get_grammar(function_calls[-1])
             completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion["choices"][0]["text"].strip())
         # If the prompt involves a function call, just append generated parameters to function_bodies
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # If tool_choice/function_call is "none"
-        if isinstance(function_call, str) and function_call == "none":
-            prompt = (
-                prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                + "all\n<|content|>"
-            )
-            stops = [STOP_TOKEN, FROM_TOKEN]
-            completion = create_completion(stop=stops)
-            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
-            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If tool_choice/function_call is provided
-        elif isinstance(function_call, dict):
+        if isinstance(function_call, dict):
             prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
             function_call = function_call["name"]
             function_calls.append(function_call)
             grammar = get_grammar(function_call)
             stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion_text.strip())
         # If "auto" or no tool_choice/function_call
         elif isinstance(function_call, str) and function_call == "auto":
@@ -2067,6 +2065,7 @@ def create_completion(stop):
                 stops = CONTENT_TOKEN
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 function_name = completion_text.strip()
                 if function_name == "all":
                     prompt += "all\n<|content|>"
@@ -2079,12 +2078,23 @@ def create_completion(stop):
                 stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 if function_name == "all":
-                    content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
+                    if completion_text.endswith("\n<|from|>assistant\n"):
+                        content += completion_text[:-len("\n<|from|>assistant\n")]
+                    elif completion_text.endswith("\n<|from|> assistant\n"):
+                        content += completion_text[:-len("\n<|from|> assistant\n")]
+                    else:
+                        content += completion_text
                     content = content.lstrip()
                     # Check whether the model wants to generate another turn
                     if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
-                        cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
+                        if completion_text.endswith("\n<|from|>assistant\n"):
+                            cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
+                        elif completion_text.endswith("\n<|from|> assistant\n"):
+                            cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip()
+                        else:
+                            cleaned_completion_text = completion_text.strip()
                         prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                     else:
                         break
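`str.removesuffix` was added in Python 3.9, so the chained `removesuffix` calls are rewritten with `endswith` checks plus slicing, presumably to keep the module importable on Python 3.8. Only one suffix variant can match, hence the `elif`, and `text[:-len(suffix)]` (not `text[-len(suffix)]`, which would index a single character) is what drops it. An equivalent 3.8-compatible helper, as a sketch:

```python
def remove_suffix(text: str, suffix: str) -> str:
    """str.removesuffix backport: drop `suffix` from the end of `text` if present."""
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text

assert remove_suffix("hi\n<|from|>assistant\n", "\n<|from|>assistant\n") == "hi"
assert remove_suffix("hi", "\n<|from|>assistant\n") == "hi"  # no-op when absent
```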
@@ -2094,6 +2104,7 @@ def create_completion(stop):
                     prompt += completion_text.strip()
                     grammar = None
                     completion = create_completion(stop=stops)
+                    completion_tokens += completion["usage"]["completion_tokens"]
                     if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                         prompt += "\n<|from|>assistant\n<|recipient|>"
                     else:
@@ -2122,12 +2133,16 @@ def create_completion(stop):
         )
 
         # TODO: support stream mode
-        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
-            "function_call": {
-                "name": tool_calls[0]["function"]["name"],
-                "arguments": tool_calls[0]["function"]["arguments"],
-            }
-        } if len(tool_calls) == 1 else {}
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {}
+        if len(tool_calls) > 0:
+            if tools is not None:
+                function_call_dict["tool_calls"] = tool_calls
+            else:
+                function_call_dict["function_call"] = {
+                    "name": tool_calls[0]["function"]["name"],
+                    "arguments": tool_calls[0]["function"]["arguments"],
+                }
+        completion["usage"]["completion_tokens"] = completion_tokens
         return llama_types.CreateChatCompletionResponse(
             id="chat" + completion["id"],
             object="chat.completion",
@@ -2140,7 +2155,6 @@ def create_completion(stop):
                     "message": {
                         "role": "assistant",
                         "content": None if content == "" else content,
-                        "tool_calls": tool_calls,
                         **function_call_dict,
                     },
                     "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",