diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 3d18d904f..b653410d3 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -185,6 +185,23 @@ def _format_llama2(
     return ret
 
 
+def _format_mistral(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the mistral style."""
+    ret = system_message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + " "
+        elif system_message and i == 1:
+            ret += role + " " + message + " " + sep
+        elif message:
+            ret += role + " " + message + " "
+        else:
+            ret += role + " "
+    return ret
+
+
 def _format_add_colon_single(
     system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
 ) -> str:
@@ -530,6 +547,20 @@ def format_llama2(
     _prompt = _format_llama2(system_message, _messages, " ", "") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
 
+@register_chat_format("mistral")
+def format_mistral(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    _system_template = "[INST] {system_message}\n"
+    _roles = dict(user="[INST]", assistant="[/INST]")
+    _messages = _map_roles(messages, _roles)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_mistral(system_message, _messages, "") + "[/INST]"
+    print(_prompt)
+    return ChatFormatterResponse(prompt=_prompt)
 
 @register_chat_format("alpaca")
 def format_alpaca(
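
For reference, a minimal usage sketch of the new "mistral" format follows. It is not part of the patch; the sample conversation and the expected prompt string are illustrative, derived from the `_format_mistral` logic in the hunks above.

# Illustrative sketch only (assumes this patch is applied); shows the prompt
# string the newly registered "mistral" formatter produces for a short chat.
from llama_cpp.llama_chat_format import format_mistral

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
    {"role": "user", "content": "How are you?"},
]

response = format_mistral(messages)
# Expected prompt, per _format_mistral above:
# "[INST] You are a helpful assistant.\n Hello [/INST] Hi! [INST] How are you? [/INST]"
print(response.prompt)

With the patch applied, the format should also be selectable through the high-level API via the existing `chat_format` parameter, e.g. `Llama(model_path=..., chat_format="mistral")`.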