@@ -1828,27 +1828,35 @@ def prepare_messages_for_inference(
     version: Literal["v1", "v2"],
     functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
     tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Union[Dict, str] = "auto",
 ):
     all_messages: List[llama_types.ChatCompletionRequestMessage] = []
-    if functions is not None:
+    if tool_choice == "none":
         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
-                role="system", content=generate_schema_from_functions(functions)
+                role="system", content=generate_schema_from_functions([])
             )
         )
-    elif tools is not None:
-        all_messages.append(
-            llama_types.ChatCompletionRequestSystemMessage(
-                role="system",
-                content=generate_schema_from_functions(
-                    [
-                        tool["function"]
-                        for tool in tools
-                        if tool["type"] == "function"
-                    ]
-                ),
+    else:
+        if functions is not None:
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system", content=generate_schema_from_functions(functions)
+                )
+            )
+        elif tools is not None and tool_choice != "none":
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system",
+                    content=generate_schema_from_functions(
+                        [
+                            tool["function"]
+                            for tool in tools
+                            if tool["type"] == "function"
+                        ]
+                    ),
+                )
             )
-        )

     all_messages.append(
         llama_types.ChatCompletionRequestSystemMessage(
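With this hunk, `tool_choice == "none"` is resolved at prompt-construction time: the model is still prompted in the functionary format, but against an empty schema, so there is nothing it can call. The `functions`/`tools` branches only apply otherwise. A minimal standalone sketch of the selection order (the function and variable names here are illustrative, not the library's API):

```python
from typing import Dict, List, Optional, Union

def select_schema_functions(
    functions: Optional[List[dict]] = None,
    tools: Optional[List[dict]] = None,
    tool_choice: Union[Dict, str] = "auto",
) -> List[dict]:
    # tool_choice == "none" wins over everything: prompt with an empty schema
    # so the model sees the functionary format but nothing to call.
    if tool_choice == "none":
        return []
    # The legacy `functions` argument takes precedence over `tools`.
    if functions is not None:
        return functions
    # Otherwise expose only the tools of type "function".
    if tools is not None:
        return [tool["function"] for tool in tools if tool["type"] == "function"]
    return []

assert select_schema_functions(
    tools=[{"type": "function", "function": {"name": "f"}}], tool_choice="none"
) == []
```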
@@ -1888,7 +1896,7 @@ def prepare_messages_for_inference(
         function_call = "auto"

     prompt = prepare_messages_for_inference(
-        messages, tokenizer, version, functions, tools
+        messages, tokenizer, version, functions, tools, function_call
     )

     # If no tools/functions are provided
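The call site now forwards the resolved `function_call` so that `prepare_messages_for_inference` can apply the `"none"` rule above. The surrounding code (only partly visible here) falls back to `"auto"` when the caller expressed no preference; a hedged sketch of that kind of normalization, with a hypothetical helper name:

```python
from typing import Dict, Optional, Union

def resolve_function_call(
    function_call: Optional[Union[str, Dict]],
    tool_choice: Optional[Union[str, Dict]],
) -> Union[str, Dict]:
    # Hypothetical normalization: treat the two OpenAI-style APIs uniformly
    # and fall back to "auto" when the caller expressed no preference.
    choice = function_call if function_call is not None else tool_choice
    if choice is None:
        return "auto"
    if isinstance(choice, dict) and "function" in choice:
        # tool_choice dict form: {"type": "function", "function": {"name": ...}}
        return choice["function"]
    return choice

assert resolve_function_call(None, None) == "auto"
assert resolve_function_call(None, {"type": "function", "function": {"name": "f"}}) == {"name": "f"}
```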
@@ -1985,17 +1993,12 @@ def create_completion(stop):

     content = ""
     function_calls, function_bodies = [], []
+    completion_tokens = 0

     if version == "v1":
         # If no or "auto" tool_choice/function_call
         if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
-        # If tool_choice/function_call is "none"
-        elif isinstance(function_call, str) and function_call == "none":
-            prompt = prepare_messages_for_inference(
-                messages, tokenizer, version, [], []
-            )
-            stops = END_ASSISTANT_TOKEN
         # If tool_choice/function_call is provided
         elif isinstance(function_call, dict):
             prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
@@ -2009,12 +2012,15 @@ def create_completion(stop):

         completion = create_completion(stop=stops)
         completion_text = completion["choices"][0]["text"]
+        completion_tokens += completion["usage"]["completion_tokens"]
+

         # If the generation does not involve a function call
         if (
             START_FUNCTION_CALL_TOKEN not in prompt
             and START_FUNCTION_CALL_TOKEN not in completion_text
         ):
+            completion["usage"]["completion_tokens"] = completion_tokens
             return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If the generation involves a function call in completion, generate the parameters
         elif (
@@ -2032,30 +2038,22 @@ def create_completion(stop):
             )
             grammar = get_grammar(function_calls[-1])
             completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion["choices"][0]["text"].strip())
         # If the prompt involves a function call, just append generated parameters to function_bodies
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # If tool_choice/function_call is "none"
-        if isinstance(function_call, str) and function_call == "none":
-            prompt = (
-                prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                + "all\n<|content|>"
-            )
-            stops = [STOP_TOKEN, FROM_TOKEN]
-            completion = create_completion(stop=stops)
-            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
-            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
         # If tool_choice/function_call is provided
-        elif isinstance(function_call, dict):
+        if isinstance(function_call, dict):
             prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
             function_call = function_call["name"]
             function_calls.append(function_call)
             grammar = get_grammar(function_call)
             stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
+            completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion_text.strip())
         # If "auto" or no tool_choice/function_call
         elif isinstance(function_call, str) and function_call == "auto":
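With the `"none"` case now handled at prompt-construction time, both of the per-version special cases are deleted and the v2 branch opens directly with the forced-call case: the requested function name is appended to the prompt so the model only has to generate arguments after the content marker. A hedged illustration of that prompt suffix (the prompt prefix is truncated for illustration; the token value is copied from this diff):

```python
CONTENT_TOKEN = "<|content|>"  # copied from the v2 format used in this diff

function_call = {"name": "get_weather"}  # dict form of tool_choice/function_call
prompt = "<|from|>assistant\n<|recipient|>"  # illustrative tail of the real prompt
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
assert prompt.endswith("get_weather\n<|content|>")
```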
@@ -2065,6 +2063,7 @@ def create_completion(stop):
                 stops = CONTENT_TOKEN
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 function_name = completion_text.strip()
                 if function_name == "all":
                     prompt += "all\n<|content|>"
@@ -2077,12 +2076,23 @@ def create_completion(stop):
                 stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                 completion = create_completion(stop=stops)
                 completion_text = completion["choices"][0]["text"]
+                completion_tokens += completion["usage"]["completion_tokens"]
                 if function_name == "all":
-                    content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
+                    if completion_text.endswith("\n<|from|>assistant\n"):
+                        content += completion_text[:-len("\n<|from|>assistant\n")]
+                    elif completion_text.endswith("\n<|from|> assistant\n"):
+                        content += completion_text[:-len("\n<|from|> assistant\n")]
+                    else:
+                        content += completion_text
                     content = content.lstrip()
                     # Check whether the model wants to generate another turn
                     if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
-                        cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
+                        if completion_text.endswith("\n<|from|>assistant\n"):
+                            cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
+                        elif completion_text.endswith("\n<|from|> assistant\n"):
+                            cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip()
+                        else:
+                            cleaned_completion_text = completion_text.strip()
                         prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                     else:
                         break
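The `str.removesuffix` calls are replaced with explicit `endswith` checks plus slicing; `removesuffix` was only added in Python 3.9, while the `endswith` form also runs on 3.8. The repeated logic could be captured in one helper; a sketch, with a hypothetical helper name that is not part of the patch:

```python
def strip_suffix(text: str, suffix: str) -> str:
    # Behaves like str.removesuffix (Python >= 3.9) on older runtimes.
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text

assert strip_suffix("hi\n<|from|>assistant\n", "\n<|from|>assistant\n") == "hi"
assert strip_suffix("hi", "\n<|from|>assistant\n") == "hi"
```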
@@ -2092,6 +2102,7 @@ def create_completion(stop):
                     prompt += completion_text.strip()
                     grammar = None
                     completion = create_completion(stop=stops)
+                    completion_tokens += completion["usage"]["completion_tokens"]
                     if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                         prompt += "\n<|from|>assistant\n<|recipient|>"
                     else:
@@ -2120,12 +2131,16 @@ def create_completion(stop):
         )

     # TODO: support stream mode
-    function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
-        "function_call": {
-            "name": tool_calls[0]["function"]["name"],
-            "arguments": tool_calls[0]["function"]["arguments"],
-        }
-    } if len(tool_calls) == 1 else {}
+    function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {}
+    if len(tool_calls) > 0:
+        if tools is not None:
+            function_call_dict["tool_calls"] = tool_calls
+        else:
+            function_call_dict["function_call"] = {
+                "name": tool_calls[0]["function"]["name"],
+                "arguments": tool_calls[0]["function"]["arguments"],
+            }
+    completion["usage"]["completion_tokens"] = completion_tokens
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
         object="chat.completion",
@@ -2138,7 +2153,6 @@ def create_completion(stop):
                 "message": {
                     "role": "assistant",
                     "content": None if content == "" else content,
-                    "tool_calls": tool_calls,
                     **function_call_dict,
                 },
                 "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",