|
1 | 1 | import json |
| 2 | +import sys |
| 3 | +import logging |
| 4 | +import ctypes |
| 5 | +from unittest.mock import MagicMock |
2 | 6 |
|
3 | 7 | import jinja2 |
4 | 8 |
|
5 | | -from llama_cpp import ( |
6 | | - ChatCompletionRequestUserMessage, |
7 | | -) |
| 9 | +# Stub the native C library and dependent modules so tests can run |
| 10 | +# without compiling llama.cpp |
| 11 | +_mock_llama_cpp = MagicMock() |
| 12 | +_mock_llama_cpp.llama_log_callback = lambda f: f # decorator passthrough |
| 13 | +_mock_llama_cpp.llama_log_set = MagicMock() |
| 14 | +sys.modules.setdefault("llama_cpp.llama_cpp", _mock_llama_cpp) |
| 15 | + |
| 16 | +_mock_llama = MagicMock() |
| 17 | +_mock_llama.StoppingCriteriaList = list |
| 18 | +_mock_llama.LogitsProcessorList = list |
| 19 | +_mock_llama.LlamaGrammar = MagicMock |
| 20 | +sys.modules.setdefault("llama_cpp.llama", _mock_llama) |
| 21 | + |
8 | 22 | import llama_cpp.llama_types as llama_types |
9 | 23 | import llama_cpp.llama_chat_format as llama_chat_format |
10 | 24 |
|
11 | | -from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter |
| 25 | +from llama_cpp.llama_chat_format import ( |
| 26 | + hf_tokenizer_config_to_chat_formatter, |
| 27 | + guess_chat_format_from_gguf_metadata, |
| 28 | + DEEPSEEK_R1_CHAT_TEMPLATE, |
| 29 | +) |
| 30 | + |
| 31 | +ChatCompletionRequestUserMessage = llama_types.ChatCompletionRequestUserMessage |
12 | 32 |
|
13 | 33 | def test_mistral_instruct(): |
14 | 34 | chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" |
@@ -87,3 +107,131 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): |
87 | 107 | ) |
88 | 108 |
|
89 | 109 | assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "") |
| 110 | + |
| 111 | + |
def test_deepseek_r1_single_turn():
    """A lone user message renders as BOS + user tag + content + assistant tag."""
    msgs = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Hello"),
    ]
    result = llama_chat_format.format_deepseek_r1(messages=msgs)

    begin_tok = "<\uff5cbegin\u2581of\u2581sentence\uff5c>"
    end_tok = "<\uff5cend\u2581of\u2581sentence\uff5c>"
    user_tok = "<\uff5cUser\uff5c>"
    asst_tok = "<\uff5cAssistant\uff5c>"

    # The prompt must end on the assistant tag so the model produces the reply.
    assert result.prompt == begin_tok + user_tok + "Hello" + asst_tok
    assert result.stop == end_tok
    assert result.added_special is True
| 128 | + |
| 129 | + |
def test_deepseek_r1_with_system_message():
    """Test DeepSeek R1 format with a system message.

    The system content is expected verbatim between the BOS token and the
    first user turn — DeepSeek R1 has no dedicated system tag.
    """
    messages = [
        llama_types.ChatCompletionRequestSystemMessage(
            role="system", content="You are a helpful assistant."
        ),
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Hi"),
    ]
    response = llama_chat_format.format_deepseek_r1(messages=messages)

    bos = "<\uff5cbegin\u2581of\u2581sentence\uff5c>"
    user_tag = "<\uff5cUser\uff5c>"
    assistant_tag = "<\uff5cAssistant\uff5c>"

    # Only the prompt layout is under test here; stop/added_special are
    # covered by test_deepseek_r1_single_turn (the unused `eos` local that
    # previously sat here has been removed).
    expected = f"{bos}You are a helpful assistant.{user_tag}Hi{assistant_tag}"
    assert response.prompt == expected
| 145 | + |
| 146 | + |
def test_deepseek_r1_multi_turn():
    """Each completed assistant turn is closed with EOS; prompt ends on the assistant tag."""
    begin_tok = "<\uff5cbegin\u2581of\u2581sentence\uff5c>"
    end_tok = "<\uff5cend\u2581of\u2581sentence\uff5c>"
    user_tok = "<\uff5cUser\uff5c>"
    asst_tok = "<\uff5cAssistant\uff5c>"

    conversation = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="What is 2+2?"),
        llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="4"),
        llama_types.ChatCompletionRequestUserMessage(role="user", content="And 3+3?"),
    ]
    result = llama_chat_format.format_deepseek_r1(messages=conversation)

    want = "".join(
        [
            begin_tok,
            user_tok, "What is 2+2?",
            asst_tok, "4", end_tok,
            user_tok, "And 3+3?",
            asst_tok,
        ]
    )
    assert result.prompt == want
| 169 | + |
| 170 | + |
def test_deepseek_r1_think_stripping():
    """Reasoning wrapped in <think>...</think> is dropped from prior assistant turns."""
    begin_tok = "<\uff5cbegin\u2581of\u2581sentence\uff5c>"
    end_tok = "<\uff5cend\u2581of\u2581sentence\uff5c>"
    user_tok = "<\uff5cUser\uff5c>"
    asst_tok = "<\uff5cAssistant\uff5c>"

    history = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Solve x+1=3"),
        llama_types.ChatCompletionRequestAssistantMessage(
            role="assistant",
            content="<think>Let me solve this step by step. x+1=3, so x=2.</think>x = 2",
        ),
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Are you sure?"),
    ]
    result = llama_chat_format.format_deepseek_r1(messages=history)

    # Only the final answer ("x = 2") survives in the replayed assistant turn.
    want = (
        begin_tok
        + user_tok + "Solve x+1=3"
        + asst_tok + "x = 2" + end_tok
        + user_tok + "Are you sure?"
        + asst_tok
    )
    assert result.prompt == want
| 197 | + |
| 198 | + |
def test_deepseek_r1_distill_aliases():
    """Distill aliases (qwen / llama) must be indistinguishable from the base formatter."""
    msgs = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Hello"),
    ]
    reference = llama_chat_format.format_deepseek_r1(messages=msgs)
    variants = [
        llama_chat_format.format_deepseek_r1_distill_qwen(messages=msgs),
        llama_chat_format.format_deepseek_r1_distill_llama(messages=msgs),
    ]

    for variant in variants:
        assert variant.prompt == reference.prompt
        assert variant.stop == reference.stop
        assert variant.added_special == reference.added_special
| 212 | + |
| 213 | + |
def test_guess_chat_format_deepseek_r1_exact_match():
    """An exact DeepSeek R1 template string maps to the 'deepseek-r1' format name."""
    detected = guess_chat_format_from_gguf_metadata(
        {"tokenizer.chat_template": DEEPSEEK_R1_CHAT_TEMPLATE}
    )
    assert detected == "deepseek-r1"
| 218 | + |
| 219 | + |
def test_guess_chat_format_deepseek_r1_heuristic():
    """A non-exact template that still contains both DeepSeek tags is detected."""
    # Deliberately NOT an exact copy of DEEPSEEK_R1_CHAT_TEMPLATE.
    template = "some preamble <\uff5cUser\uff5c> stuff <\uff5cAssistant\uff5c> more stuff"
    detected = guess_chat_format_from_gguf_metadata({"tokenizer.chat_template": template})
    assert detected == "deepseek-r1"
| 226 | + |
| 227 | + |
def test_guess_chat_format_no_match():
    """A template matching no known format yields None."""
    result = guess_chat_format_from_gguf_metadata(
        {"tokenizer.chat_template": "some unknown template"}
    )
    assert result is None
| 232 | + |
| 233 | + |
def test_guess_chat_format_no_template():
    """Metadata lacking the chat_template key yields None."""
    empty_metadata = {}
    assert guess_chat_format_from_gguf_metadata(empty_metadata) is None
0 commit comments