From edc99eda95f01f55299096042a34f083d0c5b523 Mon Sep 17 00:00:00 2001
From: Jason Cox
Date: Sun, 7 Jan 2024 18:45:48 -0800
Subject: [PATCH] Add mistral chat format

---
 llama_cpp/llama_chat_format.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 0ef7bd4a8..a33f4dda5 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -107,6 +107,23 @@ def _format_llama2(
     return ret
 
 
+def _format_mistral(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the mistral style."""
+    ret = system_message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + " "
+        elif system_message and i == 1:
+            ret += role + " " + message + " " + sep
+        elif message:
+            ret += role + " " + message + " "
+        else:
+            ret += role + " "
+    return ret
+
+
 def _format_add_colon_single(
     system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
 ) -> str:
@@ -423,6 +440,20 @@ def format_llama2(
     _prompt = _format_llama2(system_message, _messages, " ", "") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
 
+@register_chat_format("mistral")
+def format_mistral(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    _system_template = "[INST] {system_message}\n"
+    _roles = dict(user="[INST]", assistant="[/INST]")
+    _messages = _map_roles(messages, _roles)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_mistral(system_message, _messages, "") + "[/INST]"
+    print(_prompt)
+    return ChatFormatterResponse(prompt=_prompt)
 
 @register_chat_format("alpaca")
 def format_alpaca(
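
Usage sketch (not part of the patch): a minimal example of exercising the "mistral" format registered above through llama-cpp-python's existing high-level chat API. The GGUF file name is a placeholder, not something this patch provides.

    # Minimal sketch: use the "mistral" chat format added by this patch.
    # The model path is a placeholder; any Mistral-Instruct GGUF would do.
    from llama_cpp import Llama

    llm = Llama(
        model_path="./mistral-7b-instruct.Q4_K_M.gguf",  # placeholder path
        chat_format="mistral",  # handler registered via @register_chat_format("mistral")
    )

    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Name the planets in the solar system."},
        ],
    )
    print(response["choices"][0]["message"]["content"])

    # For the messages above, format_mistral builds a prompt of the form
    # "[INST] You are a helpful assistant.\n Name the planets in the solar system. [/INST]"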