From 164a3b33283f9c7430a110298011777a14c28c8e Mon Sep 17 00:00:00 2001
From: Abdelrhman Mahfouz
Date: Fri, 26 Jan 2024 04:41:07 +0200
Subject: [PATCH] Added functionary v2.2 chat handler + updated readme

---
 README.md                      |  11 +-
 llama_cpp/llama_chat_format.py | 444 ++++++++++++++++++++++++++++++++-
 2 files changed, 446 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 7813c96b0..75fdb15d4 100644
--- a/README.md
+++ b/README.md
@@ -221,12 +221,12 @@ Chat completion is available through the [`create_chat_completion`](https://llam
 The high-level API also provides a simple interface for function calling.
 
 Note that the only model that supports full function calling at this time is "functionary".
-The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
+The gguf-converted files for this model can be found here: [functionary-v2.2](https://huggingface.co/meetkai/functionary-medium-v2.2-GGUF)
 
 ```python
 >>> from llama_cpp import Llama
->>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
+>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary2")
 >>> llm.create_chat_completion(
       messages = [
         {
@@ -260,12 +260,7 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
           }
         }
       }],
-      tool_choice=[{
-        "type": "function",
-        "function": {
-          "name": "UserDetail"
-        }
-      }]
+      tool_choice="auto"
 )
 ```
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 6c274aa82..fe938cdf2 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -4,7 +4,8 @@
 import json
 import ctypes
 import dataclasses
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
+import time
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, Protocol
 
 import jinja2
 
@@ -13,6 +14,8 @@
 import llama_cpp.llama_grammar as llama_grammar
 
 from ._utils import suppress_stdout_stderr, Singleton
 
+from transformers import AutoTokenizer, logging
+from uuid import uuid4 as uuid
 
 
 class LlamaChatCompletionHandler(Protocol):
@@ -1463,3 +1466,442 @@ def __call__(
             ),
             stream=stream,
         )
+
+
+@register_chat_completion_handler("functionary2")
+def functionary2_chat_handler(
+    llama: llama.Llama,
+    messages: list[llama_types.ChatCompletionRequestMessage],
+    functions: Optional[list[llama_types.ChatCompletionFunction]] = None,
+    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+    tools: Optional[list[llama_types.ChatCompletionTool]] = None,
+    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+    top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
+    stream: bool = False,
+    stop: Optional[Union[str, list[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: float = 0.0,
+    frequency_penalty: float = 0.0,
+    repeat_penalty: float = 1.1,
+    tfs_z: float = 1.0,
+    mirostat_mode: int = 0,
+    mirostat_tau: float = 5.0,
+    mirostat_eta: float = 0.1,
+    model: Optional[str] = None,
+    logits_processor: Optional[llama.LogitsProcessorList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    **kwargs,  # type: ignore
+) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
+    logging.set_verbosity_error()
+    tokenizer = AutoTokenizer.from_pretrained(
+        "meetkai/functionary-7b-v2", legacy=True)
+    chat_id = str(uuid())
+
+    def generate_type_definition(
+        param: dict[str, llama_types.JsonType], indent_level: int, shared_defs
+    ) -> str:
+        indent = " " * indent_level
+        if "$ref" in param:
+            # Reference to a shared definition
+            ref_name = param["$ref"].split("/")[  # type: ignore
+                -1
+            ]  # Extract the type name from the reference
+            return ref_name
+        elif param.get("type") == "array":
+            items = param.get("items", {})
+            item_type = generate_type_definition(
+                items, indent_level + 1, shared_defs)  # type: ignore
+            return f"Array<{item_type}>"
+        elif param.get("type") == "object":
+            properties = param.get("properties", {})
+            nested_schema = "{\n"
+            for nested_param_name, nested_param in properties.items():  # type: ignore
+                nested_param_type = generate_type_definition(
+                    nested_param, indent_level + 1, shared_defs
+                )
+                nested_schema += (
+                    f"{indent} {nested_param_name}: {nested_param_type},\n"
+                )
+            nested_schema += indent + "}"
+            return nested_schema
+        elif "enum" in param:
+            # Enum type
+            return " | ".join([f'"{enum_value}"'
+                               for enum_value in param["enum"]])  # type: ignore
+        else:
+            # Simple type
+            return param.get("type", "any")  # type: ignore
+
+    def generate_shared_definitions(shared_defs, indent_level: int) -> str:
+        indent = " " * indent_level
+        shared_definitions = ""
+        for def_name, def_properties in shared_defs.items():
+            shared_definitions += f"{indent}type {def_name} = "
+            if def_properties.get("type") == "object":
+                shared_definitions += generate_type_definition(
+                    def_properties, indent_level, shared_defs
+                )
+            elif "enum" in def_properties:
+                # Enum type
+                shared_definitions += " | ".join(
+                    [f'"{enum_value}"' for enum_value in def_properties["enum"]]
+                )
+            shared_definitions += ";\n"
+        return shared_definitions
+
+    def generate_schema_from_functions(functions, namespace="functions") -> str:
+        schema = (
+            "// Supported function definitions that should be called when necessary.\n"
+        )
+        schema += f"namespace {namespace} {{\n\n"
+
+        # Generate shared definitions
+        shared_definitions = {}
+        for function in functions:
+            parameters = function.get("parameters", {})
+            shared_definitions.update(parameters.get("$defs", {}))
+
+        schema += generate_shared_definitions(shared_definitions, 1)
+
+        for function in functions:
+            function_name = function["name"]
+            description = function.get("description", "")
+            parameters = function.get("parameters", {})
+            required_params = parameters.get("required", [])
+
+            schema += f" // {description}\n"
+            schema += f" type {function_name} = (_: {{\n"
+
+            for param_name, param in parameters.get("properties", {}).items():
+                param_description = param.get("description", "")
+                param_type = generate_type_definition(
+                    param, 2, shared_definitions)
+                optional_indicator = "" if param_name in required_params else "?"
+                schema += f" // {param_description}\n"
+                schema += f" {param_name}{optional_indicator}: {param_type},\n"
+
+            schema += " }) => any;\n\n"
+
+        schema += "}} // namespace {}\n".format(namespace)
+        return schema
+
+    def prepare_messages_for_inference(
+        messages: list[llama_types.ChatCompletionRequestMessage],
+        functions: Optional[list[llama_types.ChatCompletionFunctions]] = None,
+        tools: Optional[list[llama_types.ChatCompletionTool]] = None,
+    ):
+        all_messages: list[llama_types.ChatCompletionRequestMessage] = []
+        if functions is not None:
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system", content=generate_schema_from_functions(functions)
+                )
+            )
+
+        if tools is not None:
+            all_messages.append(
+                llama_types.ChatCompletionRequestSystemMessage(
+                    role="system",
+                    content=generate_schema_from_functions(
+                        [
+                            tool["function"]
+                            for tool in tools
+                            if tool["type"] == "function"
+                        ]
+                    ),
+                )
+            )
+
+        all_messages.extend(messages)
+
+        def role_header(role: str, msg: str) -> str:
+            return f"<|from|>{role}\n<|recipient|>all\n<|content|>{msg}"
+
+        def function_to_str(func_name: str, args: str):
+            return f"<|from|>assistant\n<|recipient|>{func_name}\n<|content|>{args}"
+
+        def function_ret_to_str(func_name: str, ret: Optional[str] = None):
+            if ret is None:
+                return f"<|from|>{func_name}\n<|recipient|>all\n<|content|>{{}}"
+            else:
+                return f"<|from|>{func_name}\n<|recipient|>all\n<|content|>{ret}"
+
+        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
+            if msg["role"] == "system":
+                return role_header("system", msg["content"] or "")
+            elif msg["role"] == "function":
+                raise Exception(
+                    "Role \"function\" is not supported, use \"tool\" instead")
+            elif msg["role"] == "tool":
+                if msg["content"] is not None:
+                    return function_ret_to_str(msg['tool_call_id'], msg['content'] or "")
+                else:
+                    return function_ret_to_str(msg['tool_call_id'])
+            elif msg["role"] == "user":
+                s = ""
+                if isinstance(msg["content"], Iterable) and not isinstance(msg["content"], str):
+                    for part in msg["content"]:
+                        if len(s) > 0:
+                            s += "\n"
+                        if isinstance(part, str):
+                            s += part
+                        elif part["type"] == "text":
+                            s += part["text"]
+                        elif part["type"] == "image_url":
+                            raise Exception("Images are not supported")
+                else:
+                    s = msg["content"] or ""
+                return role_header("user", s)
+            elif msg["role"] == "assistant":
+                s = ""
+                if msg["content"] is not None:
+                    s += role_header("assistant", msg["content"])
+                if "tool_calls" in msg:
+                    for call in msg["tool_calls"]:
+                        if len(s) > 0:
+                            s += "\n\n"
+                        s += function_to_str(
+                            call["function"]["name"], call["function"]["arguments"])
+                s += "<|stop|>"
+                return s
+            else:
+                raise ValueError(f"Unsupported role: {msg['role']}")
+
+        return "\n\n".join([message_to_str(msg) for msg in all_messages]) + "\n\n<|from|>assistant\n<|recipient|>"
+
+    def deserialize_response_as_stream(
+        generated_tokens: Iterable[str],
+    ) -> Iterator[llama_types.ChatCompletionChunk]:
+        # state 0: reading the role after <|from|>, 1: reading the recipient
+        # after <|recipient|>, 2: reading content after <|content|>.
+        # The prompt already ends with "<|from|>assistant\n<|recipient|>",
+        # so parsing starts in state 1.
+        state = 1
+        tool_index = 0
+        function_call = False
+        function_name = ""
+        buff = ""
+
+        # First chunk announces the assistant role, as in the OpenAI streaming API.
+        yield {
+            "id": f"chat_{chat_id}",
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": model or "functionary_v2",
+            "choices": [
+                {
+                    "index": 0,
+                    "finish_reason": None,
+                    "delta": {
+                        "role": "assistant"
+                    }
+                }
+            ]
+        }
+
+        def deserialize_function_call(name: str, args: str, tool_index: int, finish: bool) -> llama_types.ChatCompletionChunk:
+            return {
+                "id": f"chat_{chat_id}",
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model or "functionary_v2",
+                "choices": [
+                    {
+                        "index": 0,
+                        "finish_reason": "tool_calls" if finish else None,
+                        "delta": {
+                            "tool_calls": [
+                                {
+                                    "id": name,
+                                    "index": tool_index,
+                                    "type": "function",
+                                    "function": {
+                                        "name": name,
+                                        "arguments": args
+                                    }
+                                }
+                            ]
+                        }
+                    }
+                ]
+            }
+
+        for token in generated_tokens:
+            if token == "<|stop|>":
+                if function_call:
+                    yield deserialize_function_call(function_name, buff.strip(), tool_index, True)
+                else:
+                    yield {
+                        "id": f"chat_{chat_id}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": model or "functionary_v2",
+                        "choices": [
+                            {
+                                "index": 0,
+                                "finish_reason": "tool_calls" if tool_index > 0 else "stop",
+                                "delta": {}
+                            }
+                        ]
+                    }
+                break
+            if token == "<|from|>":
+                if function_call:
+                    function_call = False
+                    yield deserialize_function_call(function_name, buff.strip(), tool_index, False)
+                    tool_index += 1
+                function_name = ""
+                buff = ""
+                state = 0
+                continue
+            elif token == "<|recipient|>":
+                if buff.strip() != "assistant":
+                    raise Exception(f"Assistant is pretending to be {buff}")
+                buff = ""
+                state = 1
+                continue
+            elif token == "<|content|>":
+                if buff.strip() != "all":
+                    function_call = True
+                    function_name = buff.strip()
+                buff = ""
+                state = 2
+                continue
+            if state == 2:
+                if function_call:
+                    buff += token
+                else:
+                    yield {
+                        "id": f"chat_{chat_id}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": model or "functionary_v2",
+                        "choices": [
+                            {
+                                "index": 0,
+                                "finish_reason": None,
+                                "delta": {
+                                    "content": token
+                                }
+                            }
+                        ]
+                    }
+            else:
+                buff += token
+
+    def deserialize_response(
+        completion: str,
+        usage: llama_types.CompletionUsage
+    ) -> llama_types.ChatCompletion:
+        txt = completion.replace("<|stop|>", "")
+        parts = txt.split("<|from|>")
+        content = ""
+        tool_calls: Optional[llama_types.ChatCompletionMessageToolCalls] = None
+        for i, part in enumerate(parts):
+            sections = part.split("\n")
+            if i == 0:
+                sections = ["assistant",
+                            "<|recipient|>" + sections[0], *sections[1:]]
+            if len(sections) > 3:
+                sections[2] = "\n".join(sections[2:])
+            if sections[0].strip() != "assistant":
+                raise Exception(f"Assistant is pretending to be {sections[0]}")
+            if sections[1] == "<|recipient|>all":
+                content = sections[2][11:].strip()
+            else:
+                if tool_calls is None:
+                    tool_calls = []
+                name = sections[1][13:].strip()
+                args = sections[2][11:].strip()
+                tool_calls.append({
+                    "id": name,
+                    "type": "function",
+                    "function": {
+                        "name": name,
+                        "arguments": args
+                    }
+                })
+        if tool_calls is None:
+            msg: llama_types.ChatCompletionResponseMessage = {
+                "role": "assistant",
+                "content": content,
+            }
+        else:
+            msg: llama_types.ChatCompletionResponseMessage = {
+                "role": "assistant",
+                "content": content if len(content) > 0 else None,
+                "tool_calls": tool_calls
+            }
+        return {
+            "id": f"chat_{chat_id}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": model or "functionary_v2",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": msg,
+                    "finish_reason": "tool_calls" if tool_calls is not None and len(tool_calls) > 0 else "stop",
+                }
+            ],
+            "usage": usage
+        }
+
+    prompt = prepare_messages_for_inference(messages, None, tools)
+    encoded_prompt = tokenizer.encode(prompt)
+    generated_tokens = llama.generate(
+        tokens=encoded_prompt,
+        temp=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
+        presence_penalty=presence_penalty,
+        frequency_penalty=frequency_penalty,
+        repeat_penalty=repeat_penalty,
+        tfs_z=tfs_z,
+        mirostat_mode=mirostat_mode,
+        mirostat_tau=mirostat_tau,
+        mirostat_eta=mirostat_eta,
+        logits_processor=logits_processor,
+        grammar=grammar
+    )
+    tokens_buff = []
+    end_tokens = [*tokenizer.encode("<|stop|>")]
+
+    if stop is not None:
+        if isinstance(stop, Iterable):
+            if len(stop) > 0:
+                if not isinstance(stop, str) and any(not isinstance(x, str) for x in stop):
+                    raise Exception(
+                        "Unsupported list type for stops. A list of stops may only contain strings")
+                end_tokens.extend(tokenizer.encode(stop))
+        else:
+            raise Exception(
+                "Unknown type for stops. stops must be Union[str, list[str], None]")
+
+    usage: llama_types.CompletionUsage = {"completion_tokens": 0,
+                                          "prompt_tokens": len(encoded_prompt), "total_tokens": len(encoded_prompt)}
+
+    def stream_token_to_str(generated_tokens: Iterable[int], usage: llama_types.CompletionUsage) -> Iterable[str]:
+        # Decode with a two-token window so that SentencePiece leading spaces
+        # are preserved when tokens are emitted one piece at a time.
+        for token in generated_tokens:
+            usage["completion_tokens"] += 1
+            usage["total_tokens"] += 1
+            tokens_buff.append(token)
+            if len(tokens_buff) == 1:
+                s = tokenizer.decode(tokens_buff)
+                yield s
+            elif len(tokens_buff) > 1:
+                s = tokenizer.decode(tokens_buff)
+                if " " in s:
+                    s = " " + tokenizer.decode(tokens_buff[1])
+                else:
+                    s = tokenizer.decode(tokens_buff[1])
+                yield s
+                tokens_buff.pop(0)
+            if token in end_tokens:
+                break
+
+    if stream:
+        return deserialize_response_as_stream(stream_token_to_str(generated_tokens, usage))
+    else:
+        return deserialize_response("".join(stream_token_to_str(generated_tokens, usage)), usage)
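
For reference, the prompt that `prepare_messages_for_inference` builds looks roughly like the sketch below (schema text from `generate_schema_from_functions`, indentation approximate). The `get_current_temperature` tool and the user question are made-up examples and are not part of the patch:

```
<|from|>system
<|recipient|>all
<|content|>// Supported function definitions that should be called when necessary.
namespace functions {

// Get the current temperature for a location
type get_current_temperature = (_: {
// City name
location: string,
}) => any;

} // namespace functions

<|from|>user
<|recipient|>all
<|content|>What is the current temperature in Hamburg?

<|from|>assistant
<|recipient|>
```

The model then completes with either `all\n<|content|>...` for plain text or `<function name>\n<|content|>{...}` for a call, terminated by `<|stop|>`, which is what `deserialize_response` and `deserialize_response_as_stream` parse back into OpenAI-style responses.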
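
A minimal usage sketch of the new handler, mirroring the README example above. The model path and the `get_current_temperature` tool are placeholders, and the handler loads the `meetkai/functionary-7b-v2` tokenizer via `transformers`, so that package must be installed:

```python
from llama_cpp import Llama

llm = Llama(
    model_path="path/to/functionary/llama-model.gguf",  # placeholder path
    chat_format="functionary2",  # handler registered by this patch
    n_ctx=4096,
)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_temperature",  # hypothetical example tool
            "description": "Get the current temperature for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            },
        },
    }
]

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the current temperature in Hamburg?"}],
    tools=tools,
    tool_choice="auto",
)

message = response["choices"][0]["message"]
# Either plain text content or a list of tool calls to execute.
print(message.get("content"), message.get("tool_calls"))
```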
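
With `stream=True` the handler yields OpenAI-style chunks whose `delta` carries either `content` tokens or a buffered `tool_calls` entry (see `deserialize_response_as_stream`). A consumption sketch, reusing `llm` and `tools` from the previous example:

```python
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the current temperature in Hamburg?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        # Plain-text tokens are streamed one piece at a time.
        print(delta["content"], end="", flush=True)
    elif "tool_calls" in delta:
        for call in delta["tool_calls"]:
            # Arguments arrive as one JSON string per call, buffered by the handler.
            print(call["function"]["name"], call["function"]["arguments"])
```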