From 56bdf714cfe8a7839ba20bc62bee3697758da3e1 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 1 Feb 2025 04:13:02 +0000 Subject: [PATCH 01/59] Support break / continue https://jinja.palletsprojects.com/en/stable/extensions/#loop-controls --- include/minja/minja.hpp | 49 ++++++++++++++++++++++++++++++++++++++--- tests/test-syntax.cpp | 10 +++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index f0e80fd..bcb5a08 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; static std::string typeToString(Type t) { switch (t) { @@ -714,6 +714,8 @@ class TemplateToken { case Type::EndFilter: return "endfilter"; case Type::Generation: return "generation"; case Type::EndGeneration: return "endgeneration"; + case Type::Break: return "break"; + case Type::Continue: return "continue"; } return "Unknown"; } @@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken { CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} }; +enum class LoopControlType { Break, Continue }; + +class LoopControlException : public std::runtime_error { +public: + LoopControlType control_type; + LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} + LoopControlException(LoopControlType control_type) + : std::runtime_error((std::ostringstream() << (control_type == LoopControlType::Continue ? "continue" : "break") << " outside of a loop").str()), + control_type(control_type) {} +}; + +struct LoopControlTemplateToken : public TemplateToken { + LoopControlType control_type; + LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} +}; + class TemplateNode { Location location_; protected: @@ -825,6 +843,12 @@ class TemplateNode { void render(std::ostringstream & out, const std::shared_ptr & context) const { try { do_render(out, context); + } catch (const LoopControlException & e) { + // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop. 
+ std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw LoopControlException(err.str(), e.control_type); } catch (const std::exception & e) { std::ostringstream err; err << e.what(); @@ -897,6 +921,15 @@ class IfNode : public TemplateNode { } }; +class LoopControlNode : public TemplateNode { + LoopControlType control_type_; + public: + LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} + void do_render(std::ostringstream &, const std::shared_ptr &) const override { + throw LoopControlException(control_type_); + } +}; + class ForNode : public TemplateNode { std::vector var_names; std::shared_ptr iterable; @@ -961,7 +994,12 @@ class ForNode : public TemplateNode { loop.set("last", i == (n - 1)); loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); - body->render(out, loop_context); + try { + body->render(out, loop_context); + } catch (const LoopControlException & e) { + if (e.control_type == LoopControlType::Break) break; + if (e.control_type == LoopControlType::Continue) continue; + } } } }; @@ -2159,7 +2197,7 @@ class Parser { static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); @@ -2291,6 +2329,9 @@ class Parser { } else if (keyword == "endfilter") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "break" || keyword == "continue") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, keyword == "break" ? 
LoopControlType::Break : LoopControlType::Continue)); } else { throw std::runtime_error("Unexpected block: " + keyword); } @@ -2414,6 +2455,8 @@ class Parser { children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); } else if (dynamic_cast(token.get())) { // Ignore comments + } else if (auto ctrl_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index ebe5e19..1d85c61 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -73,6 +73,7 @@ TEST(SyntaxTest, SimpleCases) { auto ThrowsWithSubstr = [](const std::string & expected_substr) { return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); }; + EXPECT_EQ( " b", render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks)); @@ -486,10 +487,19 @@ TEST(SyntaxTest, SimpleCases) { "", render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {})); + EXPECT_EQ( + "0,1,2,", + render("{% for i in range(10) %}{{ i }},{% if i == 2 %}{% break %}{% endif %}{% endfor %}", {}, {})); + EXPECT_EQ( + "0,2,4,6,8,", + render("{% for i in range(10) %}{% if i % 2 %}{% continue %}{% endif %}{{ i }},{% endfor %}", {}, {})); if (!getenv("USE_JINJA2")) { // TODO: capture stderr from jinja2 and test these. + EXPECT_THAT([]() { render("{% break %}", {}, {}); }, ThrowsWithSubstr("break outside of a loop")); + EXPECT_THAT([]() { render("{% continue %}", {}, {}); }, ThrowsWithSubstr("continue outside of a loop")); + EXPECT_THAT([]() { render("{%- set _ = [].pop() -%}", {}, {}); }, ThrowsWithSubstr("pop from empty list")); EXPECT_THAT([]() { render("{%- set _ = {}.pop() -%}", {}, {}); }, ThrowsWithSubstr("pop")); EXPECT_THAT([]() { render("{%- set _ = {}.pop('foooo') -%}", {}, {}); }, ThrowsWithSubstr("foooo")); From 3ea75b37800e51db60fdef27e8a0a4149b580fb0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 1 Feb 2025 04:45:32 +0000 Subject: [PATCH 02/59] Test CohereForAI/c4ai-command-r7b-12-2024 --- tests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ccc942..4b6952e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -55,6 +55,7 @@ set(MODEL_IDS abacusai/Fewshot-Metamath-OrcaVicuna-Mistral bofenghuang/vigogne-2-70b-chat CohereForAI/c4ai-command-r-plus # Gated + CohereForAI/c4ai-command-r7b-12-2024 # Gated databricks/dbrx-instruct # Gated google/gemma-2-2b-it # Gated google/gemma-7b-it # Gated From 449d541da39c019b914a28f13397a7122bf72972 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 1 Feb 2025 10:42:08 +0000 Subject: [PATCH 03/59] Disable CohereForAI/c4ai-command-r7b-12-2024 test on WIN32 (https://github.com/google/minja/issues/40) --- tests/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4b6952e..96c368a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -55,7 +55,6 @@ set(MODEL_IDS abacusai/Fewshot-Metamath-OrcaVicuna-Mistral bofenghuang/vigogne-2-70b-chat CohereForAI/c4ai-command-r-plus # Gated - CohereForAI/c4ai-command-r7b-12-2024 # Gated databricks/dbrx-instruct # Gated google/gemma-2-2b-it # Gated google/gemma-7b-it # Gated @@ -101,7 +100,8 @@ set(MODEL_IDS if(NOT WIN32) list(APPEND MODEL_IDS - # Needs investigation + # Needs investigation 
(https://github.com/google/minja/issues/40) + CohereForAI/c4ai-command-r7b-12-2024 # Gated deepseek-ai/deepseek-coder-33b-instruct deepseek-ai/DeepSeek-Coder-V2-Instruct deepseek-ai/DeepSeek-V2.5 From a06265c385ed41522b96a5760434a913cff1b0bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 1 Feb 2025 13:15:43 +0000 Subject: [PATCH 04/59] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d6ffc7c..755bcb7 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ Main limitations (non-exhaustive list): ## Roadmap / TODOs - [x] Fix known issues w/ CRLF on Windows -- [ ] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 +- [x] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 - Improve fuzzing coverage: - use thirdparty jinja grammar to guide exploration of inputs (or implement prettification of internal ASTs and use them to generate arbitrary values) - fuzz each filter / test From 855edc42378b8b44bd134cb312a1014509ee58fb Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 1 Feb 2025 22:48:20 +0000 Subject: [PATCH 05/59] Avoid ostringstream::str which breaks openEuler build on llama.cpp --- include/minja/minja.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index bcb5a08..e77eb69 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -824,7 +824,7 @@ class LoopControlException : public std::runtime_error { LoopControlType control_type; LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} LoopControlException(LoopControlType control_type) - : std::runtime_error((std::ostringstream() << (control_type == LoopControlType::Continue ? "continue" : "break") << " outside of a loop").str()), + : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")), control_type(control_type) {} }; From 2ded86cff8c2eec61cc2b9f1023fa23954bfc838 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 12:01:04 +0000 Subject: [PATCH 06/59] Update README.md --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 755bcb7..9b4c312 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ _**This is not an official Google product**_ -Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (such as [llama.cpp](https://github.com/ggerganov/llama.cpp) or [gemma.cpp](https://github.com/google/gemma.cpp)). +Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016) and [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433)). It is **not general purpose**: it includes just what’s needed for actual chat templates (very limited set of filters, tests and language features). Users with different needs should look at third-party alternatives such as [Jinja2Cpp](https://github.com/jinja2cpp/Jinja2Cpp), [Jinja2CppLight](https://github.com/hughperkins/Jinja2CppLight), or [inja](https://github.com/pantor/inja) (none of which we endorse). 
@@ -110,7 +110,8 @@ Main limitations (non-exhaustive list): ## Roadmap / TODOs -- [x] Fix known issues w/ CRLF on Windows +- [ ] Fix known line difference issues on Windows +- [ ] Propose integration w/ https://github.com/google/gemma.cpp - [x] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 - Improve fuzzing coverage: - use thirdparty jinja grammar to guide exploration of inputs (or implement prettification of internal ASTs and use them to generate arbitrary values) From 331652e6eda9443be01115181b6f5ece76b0b61d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 12:04:45 +0000 Subject: [PATCH 07/59] Update README.md --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9b4c312..85a3089 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ It is **not general purpose**: it includes just what’s needed for actual chat ## Usage: -This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++11 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json)'s `json.hpp` in your include path. +This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++17 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json)'s `json.hpp` in your include path. -See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.h](./include/minja/chat-template.hpp) (experimental). +See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.hpp](./include/minja/chat-template.hpp) (experimental). For raw Jinja templating (see [examples/raw.cpp](./examples/raw.cpp)): @@ -94,10 +94,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji - Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}` - `if` / `elif` / `else` / `endif` - `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring +- `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39)) - `set` w/ namespaces & destructuring - `macro` / `endmacro` - `filter` / `endfilter` -- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject`, `tojson`, `trim` +- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim` Main limitations (non-exhaustive list): @@ -111,6 +112,7 @@ Main limitations (non-exhaustive list): ## Roadmap / TODOs - [ ] Fix known line difference issues on Windows +- [ ] Document the various capabilities detectors + backfill strategies used - [ ] Propose integration w/ https://github.com/google/gemma.cpp - [x] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 - Improve fuzzing coverage: From 4d6481998eb3a4d620fa6eae4498e55faed6af16 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 14:19:11 +0000 Subject: [PATCH 08/59] Tool support backfil: provide automated tool call example --- include/minja/chat-template.hpp | 50 ++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 
deletion(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 58e119a..1900950 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -41,6 +41,7 @@ class chat_template { std::string bos_token_; std::string eos_token_; std::shared_ptr template_root_; + std::string tool_call_example_; std::string try_raw_render( const nlohmann::ordered_json & messages, @@ -176,6 +177,43 @@ class chat_template { caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); } + + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (json { + {"arg1", "some_value"}, + }).dump()}, + }}, + }, + })}, + }; + const json tools; + auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + } else { + throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); + } + } else { + + } + tool_call_example_ = full.substr(prefix.size()); + } } const std::string & source() const { return source_; } @@ -229,7 +267,17 @@ class chat_template { }; auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools; - for (const auto & message_ : needs_tools_in_system ? 
add_system(messages, "Available tools: " + tools.dump(2)) : messages) { + json adjusted_messages; + if (needs_tools_in_system) { + adjusted_messages = add_system(messages, + "\n\n" + "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + } else { + adjusted_messages = messages; + } + + for (const auto & message_ : adjusted_messages) { auto message = message_; if (!message.contains("role") || !message.contains("content")) { throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); From 0026057b987c5b4bb02ffda4480c2a3e94e775e4 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:39:30 +0000 Subject: [PATCH 09/59] Switch naming towards polyfills (and skip tool-related logic if there's no tools) --- include/minja/chat-template.hpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 1900950..862cfeb 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -50,7 +50,7 @@ class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* apply_polyfills= */ false); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -226,19 +226,21 @@ class chat_template { const nlohmann::ordered_json & tools, bool add_generation_prompt, const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool adjust_inputs = true) const + bool apply_polyfills = true) const { json actual_messages; - auto needs_adjustments = adjust_inputs && (false + auto needs_polyfills = apply_polyfills && (false || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments + || (!tools.is_null() && (false + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments + )) || caps_.requires_typed_content ); - if (needs_adjustments) { + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { From 23f4378a560223db2f0cf1399046859a4d39b480 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:40:34 +0000 Subject: [PATCH 10/59] rename test-chat-template -> test-supported-template --- tests/CMakeLists.txt | 8 ++++---- ...template.cpp => test-supported-template.cpp} | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) rename tests/{test-chat-template.cpp => test-supported-template.cpp} (89%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 96c368a..0d1c40e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,9 +39,9 @@ add_test(NAME test-syntax-jinja2 COMMAND test-syntax) set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}") -add_executable(test-chat-template test-chat-template.cpp) -target_compile_features(test-chat-template PUBLIC cxx_std_17) -target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json) +add_executable(test-supported-template 
test-supported-template.cpp) +target_compile_features(test-supported-template PUBLIC cxx_std_17) +target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) set(MODEL_IDS # List of model IDs to test the chat template of. @@ -133,7 +133,7 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES}) separate_arguments(test_args UNIX_COMMAND "${test_case}") list(GET test_args -1 last_arg) string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}") - add_test(NAME ${test_name} COMMAND $ ${test_args}) + add_test(NAME ${test_name} COMMAND $ ${test_args}) set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127) endforeach() diff --git a/tests/test-chat-template.cpp b/tests/test-supported-template.cpp similarity index 89% rename from tests/test-chat-template.cpp rename to tests/test-supported-template.cpp index 6f8bcb6..f23cd8c 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-supported-template.cpp @@ -55,6 +55,14 @@ static std::string read_file(const std::string &path) { return out; } +static void write_file(const std::string &path, const std::string &content) { + std::ofstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + fs.write(content.data(), content.size()); +} + #ifndef _WIN32 static json caps_to_json(const minja::chat_template_caps &caps) { return { @@ -132,7 +140,14 @@ int main(int argc, char *argv[]) { return 1; } - assert_equals(expected, actual); + if (expected != actual) { + if (getenv("WRITE_GOLDENS")) { + write_file(golden_file, actual); + std::cerr << "Updated golden file: " << golden_file << std::endl; + } else { + assert_equals(expected, actual); + } + } // Some unresolved CRLF issues again with the goldens on Windows. 
#ifndef _WIN32 From 51c5ac5e738f66219dee3eb1517c0431989b802d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:51:32 +0000 Subject: [PATCH 11/59] refactor fetch_templates_and_goldens --- include/minja/chat-template.hpp | 97 ++++--- scripts/fetch_templates_and_goldens.py | 366 +++++++++++++------------ 2 files changed, 248 insertions(+), 215 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 862cfeb..eee28cb 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -178,42 +178,42 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - if (!caps_.supports_tools) { - const json user_msg { - {"role", "user"}, - {"content", "Hey"}, - }; - const json tool_call_msg { - {"role", "assistant"}, - {"content", nullptr}, - {"tool_calls", json::array({ - { - // TODO: detect if requires numerical id or fixed length == 6 like Nemo - {"id", "call_1___"}, - {"type", "function"}, - {"function", { - {"name", "tool_name"}, - {"arguments", (json { - {"arg1", "some_value"}, - }).dump()}, - }}, - }, - })}, - }; - const json tools; - auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); - auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); - if (full.find(prefix) != 0) { - if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - } else { - throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); - } - } else { - - } - tool_call_example_ = full.substr(prefix.size()); - } + // if (!caps_.supports_tools) { + // const json user_msg { + // {"role", "user"}, + // {"content", "Hey"}, + // }; + // const json tool_call_msg { + // {"role", "assistant"}, + // {"content", nullptr}, + // {"tool_calls", json::array({ + // { + // // TODO: detect if requires numerical id or fixed length == 6 like Nemo + // {"id", "call_1___"}, + // {"type", "function"}, + // {"function", { + // {"name", "tool_name"}, + // {"arguments", (json { + // {"arg1", "some_value"}, + // }).dump()}, + // }}, + // }, + // })}, + // }; + // const json tools; + // auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + // auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + // if (full.find(prefix) != 0) { + // if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + // prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + // } else { + // throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); + // } + // } else { + + // } + // tool_call_example_ = full.substr(prefix.size()); + // } } const std::string & source() const { return source_; } @@ -232,13 +232,19 @@ class chat_template { auto needs_polyfills = apply_polyfills && (false || !caps_.supports_system_role - || (!tools.is_null() && (false - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments - )) + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments || caps_.requires_typed_content + // || !caps_.supports_system_role + // || (!tools.is_null() && (false + // || !caps_.supports_tools + // || !caps_.supports_tool_responses + // || !caps_.supports_tool_calls + // || caps_.requires_object_arguments + 
// )) + // || caps_.requires_typed_content ); if (needs_polyfills) { actual_messages = json::array(); @@ -272,9 +278,10 @@ class chat_template { json adjusted_messages; if (needs_tools_in_system) { adjusted_messages = add_system(messages, - "\n\n" - "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" - "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + "Available tools: " + tools.dump(2)); + // "\n\n" + // "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); } else { adjusted_messages = messages; } diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 619b539..3251b10 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -99,120 +99,208 @@ def to_json(self): # "requires_non_null_content": self.requires_non_null_content, "requires_typed_content": self.requires_typed_content, }, indent=2) - -def detect_caps(template_file, template): - basic_extra_context = { - "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>", - } - def try_raw_render(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + + +class chat_template: + + def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + basic_extra_context = { + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>", + } + try: - out = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) + out = self.template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) # print(out, file=sys.stderr) return out except BaseException as e: # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - - caps = TemplateCaps() - - user_needle = "" - sys_needle = "" - dummy_str_user_msg = {"role": "user", "content": user_needle } - dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} - - caps.requires_typed_content = \ - (user_needle not in try_raw_render([dummy_str_user_msg])) \ - and (user_needle in try_raw_render([dummy_typed_user_msg])) - dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg - - needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} - - caps.supports_system_role = sys_needle in try_raw_render([needle_system_msg, dummy_user_msg]) - - out = try_raw_render([dummy_user_msg], tools=[{ - "name": "some_tool", - "type": "function", - "function": { + + def __init__(self, template, env=None): + if not env: + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[jinja2.ext.loopcontrols] + ) + self.env = env + self.template = env.from_string(template) + + caps = TemplateCaps() + + user_needle = "" + sys_needle = "" + dummy_str_user_msg = {"role": "user", "content": user_needle } + dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} + + caps.requires_typed_content = \ + (user_needle not in self.try_raw_render([dummy_str_user_msg])) \ + and (user_needle in self.try_raw_render([dummy_typed_user_msg])) + dummy_user_msg = dummy_typed_user_msg if 
caps.requires_typed_content else dummy_str_user_msg + + needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} + + caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg]) + + out = self.try_raw_render([dummy_user_msg], tools=[{ "name": "some_tool", - "description": "Some tool", - "parameters": { - "type": "object", - "properties": { - "arg": { - "type": "string", - "description": "Some arg", + "type": "function", + "function": { + "name": "some_tool", + "description": "Some tool", + "parameters": { + "type": "object", + "properties": { + "arg": { + "type": "string", + "description": "Some arg", + }, }, + "required": ["arg"], }, - "required": ["arg"], }, - }, - }]) - caps.supports_tools = "some_tool" in out - - def make_tool_calls_msg(tool_calls, content=None): - return { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - def make_tool_call(tool_name, arguments): - return { - "id": "call_1___", - "type": "function", - "function": { - "arguments": arguments, - "name": tool_name, + }]) + caps.supports_tools = "some_tool" in out + + def make_tool_calls_msg(tool_calls, content=None): + return { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + def make_tool_call(tool_name, arguments): + return { + "id": "call_1___", + "type": "function", + "function": { + "arguments": arguments, + "name": tool_name, + } } - } - - dummy_args_obj = {"argument_needle": "print('Hello, World!')"} - - out = try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), - ]) - tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out - out = try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), - ]) - tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out - caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments - caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - - empty_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}]) - none_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}]) - caps.requires_non_null_content = \ - (user_needle in try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ - and (user_needle not in try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) - - if caps.supports_tool_calls: - dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) - tc1 = make_tool_call("test_tool1", dummy_args) - tc2 = make_tool_call("test_tool2", dummy_args) - out = try_raw_render([ + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} + + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1, tc2]), + make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), ]) - caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - - out = try_raw_render([ + tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1]), - { - "role": "tool", - "name": "test_tool1", - "content": "Some response!", - "tool_call_id": "call_911_", - } + make_tool_calls_msg([make_tool_call("ipython", 
dummy_args_obj)]), ]) - caps.supports_tool_responses = "Some response!" in out - caps.supports_tool_call_id = "call_911_" in out - - return caps - + tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out + + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments + caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments + + caps.requires_non_null_content = \ + (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ + and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) + + if caps.supports_tool_calls: + dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) + tc1 = make_tool_call("test_tool1", dummy_args) + tc2 = make_tool_call("test_tool2", dummy_args) + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1, tc2]), + ]) + caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out + + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1]), + { + "role": "tool", + "name": "test_tool1", + "content": "Some response!", + "tool_call_id": "call_911_", + } + ]) + caps.supports_tool_responses = "Some response!" in out + caps.supports_tool_call_id = "call_911_" in out + + self.original_caps = caps + + def needs_polyfills(self, context): + has_tools = context.get('tools') is not None + caps = self.original_caps + return not caps.supports_system_role \ + or (has_tools is not None and (False \ + or not caps.supports_tools \ + or not caps.supports_tool_responses \ + or not caps.supports_tool_calls \ + or caps.requires_object_arguments \ + )) \ + or caps.requires_typed_content + + def apply(self, context): + + caps = self.original_caps + has_tools = 'tools' in context + + if self.needs_polyfills(context): + if has_tools and not caps.supports_tools: + add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") + + for message in context['messages']: + if 'tool_calls' in message: + for tool_call in message['tool_calls']: + if caps.requires_object_arguments: + if tool_call.get('type') == 'function': + arguments = tool_call['function']['arguments'] + try: + arguments = json.loads(arguments) + except: + pass + tool_call['function']['arguments'] = arguments + if not caps.supports_tool_calls: + message['content'] = json.dumps({ + "tool_calls": [ + { + "name": tc['function']['name'], + "arguments": json.loads(tc['function']['arguments']), + "id": tc.get('id'), + } + for tc in message['tool_calls'] + ], + "content": None if message.get('content', '') == '' else message['content'], + }, indent=2) + del message['tool_calls'] + if message.get('role') == 'tool' and not caps.supports_tool_responses: + message['role'] = 'user' + message['content'] = json.dumps({ + "tool_response": { + "tool": message['name'], + "content": message['content'], + "tool_call_id": message.get('tool_call_id'), + } + }, indent=2) + del message['name'] + + if caps.requires_typed_content: + for message in context['messages']: + if 'content' in message and isinstance(message['content'], str): + message['content'] = [{"type": "text", "text": message['content']}] + + try: + return self.template.render(**context) + except Exception as e1: + for message in context['messages']: + if message.get("content") is None: + message["content"] = "" + + try: + return self.template.render(**context) + except Exception 
as e2: + logger.info(f" ERROR: {e2} (after first error: {e1})") + return f"ERROR: {e2}" + + + + async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): if '{% generation %}' in template_src: print('Removing {% generation %} blocks from template', file=sys.stderr) @@ -221,24 +309,19 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c model_name = model_id.replace("/", "-") base_name = f'{model_name}-{variant}' if variant else model_name template_file = join_cmake_path(output_folder, f'{base_name}.jinja') + caps_file = join_cmake_path(output_folder, f'{base_name}.caps.json') async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[jinja2.ext.loopcontrols] - ) - template = env.from_string(template_src) - - env.filters['safe'] = lambda x: x - env.filters['tojson'] = tojson - env.globals['raise_exception'] = raise_exception - env.globals['strftime_now'] = strftime_now + template = chat_template(template_src) + template.env.filters['safe'] = lambda x: x + template.env.filters['tojson'] = tojson + template.env.globals['raise_exception'] = raise_exception + template.env.globals['strftime_now'] = strftime_now - caps = detect_caps(template_file, template) + caps = template.original_caps if not context_files: print(f"{template_file} {caps_file} n/a {template_file}") @@ -252,74 +335,17 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c async with aiofiles.open(context_file, 'r') as f: context = json.loads(await f.read()) - has_tools = 'tools' in context - needs_tools_in_system = has_tools and not caps.supports_tools - - if not caps.supports_tool_calls and has_tools: + if not caps.supports_tool_calls and context.get('tools') is not None: print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue + needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): continue output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt') - if needs_tools_in_system: - add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") - - for message in context['messages']: - if 'tool_calls' in message: - for tool_call in message['tool_calls']: - if caps.requires_object_arguments: - if tool_call.get('type') == 'function': - arguments = tool_call['function']['arguments'] - try: - arguments = json.loads(arguments) - except: - pass - tool_call['function']['arguments'] = arguments - if not caps.supports_tool_calls: - message['content'] = json.dumps({ - "tool_calls": [ - { - "name": tc['function']['name'], - "arguments": json.loads(tc['function']['arguments']), - "id": tc.get('id'), - } - for tc in message['tool_calls'] - ], - "content": None if message.get('content', '') == '' else message['content'], - }, indent=2) - del message['tool_calls'] - if message.get('role') == 'tool' and not caps.supports_tool_responses: - message['role'] = 'user' - message['content'] = json.dumps({ - "tool_response": { - "tool": message['name'], - "content": message['content'], - "tool_call_id": message.get('tool_call_id'), - } - }, indent=2) - del message['name'] - - if caps.requires_typed_content: - for message in context['messages']: - if 'content' in message and 
isinstance(message['content'], str): - message['content'] = [{"type": "text", "text": message['content']}] - - try: - output = template.render(**context) - except Exception as e1: - for message in context['messages']: - if message.get("content") is None: - message["content"] = "" - - try: - output = template.render(**context) - except Exception as e2: - logger.info(f" ERROR: {e2} (after first error: {e1})") - output = f"ERROR: {e2}" - + output = template.apply(context) async with aiofiles.open(output_file, 'w') as f: await f.write(output) From 876fecb35b765b25031ba94da682fbe06cc991bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:59:51 +0000 Subject: [PATCH 12/59] print test command in format pastable to vscode debugger --- tests/test-supported-template.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index f23cd8c..e35eeb5 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -104,10 +104,8 @@ int main(int argc, char *argv[]) { return 127; } - std::cout << "# Testing template: " << tmpl_file << std::endl - << "# With caps: " << caps_file << std::endl - << "# With context: " << ctx_file << std::endl - << "# Against golden file: " << golden_file << std::endl + std::cout << "# Testing template:\n" + << "# ./build/bin/test-supported-template " << json::array({tmpl_file, caps_file, ctx_file, golden_file}).dump() << std::endl << std::flush; auto ctx = json::parse(read_file(ctx_file)); From 88b1788e6d22e5e2360733e1afc3be372371aea1 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:09:37 +0000 Subject: [PATCH 13/59] Fix multiline comments (. isn't multiline, need [\s\S\r\n]) --- include/minja/minja.hpp | 2 +- tests/test-syntax.cpp | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index e77eb69..0eadae8 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2194,7 +2194,7 @@ class Parser { } TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 1d85c61..b6f6d57 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -74,6 +74,10 @@ TEST(SyntaxTest, SimpleCases) { return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); }; + EXPECT_EQ( + "ok", + render("{# Hey\nHo #}{#- Multiline...\nComments! 
-#}{{ 'ok' }}{# yo #}", {}, {})); + EXPECT_EQ( " b", render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks)); From 53fabddc8a3ac7b4a316961b51032908035fa637 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:19:05 +0000 Subject: [PATCH 14/59] Update chat-template.hpp --- include/minja/chat-template.hpp | 74 +++++++++++++++++---------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index eee28cb..e6c8420 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -178,41 +178,45 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - // if (!caps_.supports_tools) { - // const json user_msg { - // {"role", "user"}, - // {"content", "Hey"}, - // }; - // const json tool_call_msg { - // {"role", "assistant"}, - // {"content", nullptr}, - // {"tool_calls", json::array({ - // { - // // TODO: detect if requires numerical id or fixed length == 6 like Nemo - // {"id", "call_1___"}, - // {"type", "function"}, - // {"function", { - // {"name", "tool_name"}, - // {"arguments", (json { - // {"arg1", "some_value"}, - // }).dump()}, - // }}, - // }, - // })}, - // }; - // const json tools; - // auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); - // auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); - // if (full.find(prefix) != 0) { - // if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - // prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - // } else { - // throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); - // } - // } else { - - // } - // tool_call_example_ = full.substr(prefix.size()); + // try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (json { + {"arg1", "some_value"}, + }).dump()}, + }}, + }, + })}, + }; + const json tools; + auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + // } else { + // throw std::runtime_error("# prefix not found at start of prefix:\n" + prefix + "\n# vs full:\n" + full + "\n#"); + } + } else { + + } + tool_call_example_ = full.substr(prefix.size()); + } + // } catch (const std::exception & e) { + // fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); // } } From fb7a3b31c23ec3c122839c1a17bccf2a65ded36c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:41:44 +0000 Subject: [PATCH 15/59] fix compilation --- examples/chat-template.cpp | 22 ++++++++++++---------- include/minja/chat-template.hpp | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/chat-template.cpp b/examples/chat-template.cpp index 8161b2b..d1838e7 100644 --- a/examples/chat-template.cpp +++ b/examples/chat-template.cpp @@ -19,14 +19,16 @@ int main() { /* 
bos_token= */ "<|start|>", /* eos_token= */ "<|end|>" ); - std::cout << tmpl.apply( - json::parse(R"([ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"} - ])"), - json::parse(R"([ - {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} - ])"), - /* add_generation_prompt= */ true, - /* extra_context= */ {}) << std::endl; + + minja::chat_template_inputs inputs; + inputs.messages = json::parse(R"([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"} + ])"); + inputs.add_generation_prompt = true; + inputs.tools = json::parse(R"([ + {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} + ])"); + + std::cout << tmpl.apply(inputs) << std::endl; } diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index b1f8ee3..5e901f1 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -231,7 +231,7 @@ class chat_template { inputs.messages = json::array({user_msg}); inputs.add_generation_prompt = true; auto prefix = apply(inputs); - + inputs.messages.push_back(tool_call_msg); inputs.add_generation_prompt = false; auto full = apply(inputs); From 354e77a9da87caec5a15c7294c200c8cc536aa6b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:58:35 +0000 Subject: [PATCH 16/59] implement strftime_now + respect new opts --- include/minja/chat-template.hpp | 40 ++++++++++++++++++++----------- tests/test-supported-template.cpp | 8 +++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 5e901f1..8989fa3 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -204,7 +204,8 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - // try { +#if 0 + try { if (!caps_.supports_tools) { const json user_msg { {"role", "user"}, @@ -247,9 +248,10 @@ class chat_template { } tool_call_example_ = full.substr(prefix.size()); } - // } catch (const std::exception & e) { - // fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); - // } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } +#endif } const std::string & source() const { return source_; } @@ -415,18 +417,28 @@ class chat_template { auto context = minja::Context::make(json({ {"messages", actual_messages}, {"add_generation_prompt", inputs.add_generation_prompt}, - {"bos_token", bos_token_}, - {"eos_token", eos_token_}, - // {"strftime_now", Value::callable([=](const std::shared_ptr & context, minja::ArgumentsValue & args) { - // args.expectArgs("strftime_now", {1, 1}, {0, 0}); - // auto format = args.args[0].get(); - // return Value(std::to_string(inputs.now)); - // })}, })); - + if (opts.use_bos_token) { + context->set("bos_token", bos_token_); + } + if (opts.use_eos_token) { + context->set("eos_token", eos_token_); + } + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); + + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); + } if (!inputs.tools.is_null()) { - auto tools_val = 
minja::Value(inputs.tools); - context->set("tools", tools_val); + context->set("tools", minja::Value(inputs.tools)); } if (!inputs.extra_context.is_null()) { for (auto & kv : inputs.extra_context.items()) { diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 71625d6..a7f69a3 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -18,6 +18,8 @@ #undef NDEBUG #include +#define TEST_DATE (getenv("TEST_DATE") ? getenv("TEST_DATE") : "2024-07-26") + using json = nlohmann::ordered_json; template @@ -128,6 +130,12 @@ int main(int argc, char *argv[]) { inputs.messages = ctx.at("messages"); inputs.tools = ctx.contains("tools") ? ctx.at("tools") : json(); inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + + std::istringstream ss(TEST_DATE); + std::tm tm = {}; + ss >> std::get_time(&tm, "%Y-%m-%d"); + inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + if (ctx.contains("tools")) { inputs.extra_context = json { {"builtin_tools", { From a6a5cba00eea62fe3385a8dd1deeb4fc88f92fa5 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:17:54 +0000 Subject: [PATCH 17/59] more defensive against non arrays in reject/accept + fix test builtin_tools --- include/minja/minja.hpp | 5 +++++ tests/test-supported-template.cpp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 0eadae8..069358e 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2695,6 +2695,10 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2772,6 +2776,7 @@ inline std::shared_ptr Context::builtins() { auto & items = args.args[0]; if (items.is_null()) return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); auto attr_name = args.args[1].get(); bool has_test = false; diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index a7f69a3..6cbfb7c 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -139,7 +139,7 @@ int main(int argc, char *argv[]) { if (ctx.contains("tools")) { inputs.extra_context = json { {"builtin_tools", { - {"wolfram_alpha", "brave_search"} + json::array({"wolfram_alpha", "brave_search"}) }}, }; } From f74b40d43ce6215dac33d6d341936d0917b43c6e Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:25:56 +0000 Subject: [PATCH 18/59] rename test script --- README.md | 4 ++-- scripts/{run_fuzzing_mode.sh => fuzzing_tests.sh} | 0 scripts/{run_tests.sh => tests.sh} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename scripts/{run_fuzzing_mode.sh => fuzzing_tests.sh} (100%) rename scripts/{run_tests.sh => tests.sh} (100%) diff --git a/README.md b/README.md index 85a3089..d45c45e 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Main limitations (non-exhaustive list): huggingface-cli login ``` -- Build & run tests (shorthand: `./scripts/run_tests.sh`): +- Build & run tests (shorthand: 
`./scripts/tests.sh`): ```bash rm -fR build && \ @@ -195,7 +195,7 @@ Main limitations (non-exhaustive list): - Build in [fuzzing mode](https://github.com/google/fuzztest/blob/main/doc/quickstart-cmake.md#fuzzing-mode) & run all fuzzing tests (optionally, set a higher `TIMEOUT` as env var): ```bash - ./scripts/run_fuzzing_mode.sh + ./scripts/fuzzing_tests.sh ``` - If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new): diff --git a/scripts/run_fuzzing_mode.sh b/scripts/fuzzing_tests.sh similarity index 100% rename from scripts/run_fuzzing_mode.sh rename to scripts/fuzzing_tests.sh diff --git a/scripts/run_tests.sh b/scripts/tests.sh similarity index 100% rename from scripts/run_tests.sh rename to scripts/tests.sh From a10a911edc4bc58312660e9e9e14147576971f0c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:26:14 +0000 Subject: [PATCH 19/59] more type defensiveness --- include/minja/chat-template.hpp | 3 +-- include/minja/minja.hpp | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 8989fa3..eed6946 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -442,8 +442,7 @@ class chat_template { } if (!inputs.extra_context.is_null()) { for (auto & kv : inputs.extra_context.items()) { - minja::Value val(kv.value()); - context->set(kv.key(), val); + context->set(kv.key(), minja::Value(kv.value())); } } diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 069358e..c304b5c 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2615,6 +2615,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); std::ostringstream oss; auto first = true; for (size_t i = 0, n = items.size(); i < n; ++i) { From 968ea9f8a3fead0efd3637f164df15468d15c300 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:28:29 +0000 Subject: [PATCH 20/59] fix json typo --- tests/test-supported-template.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 6cbfb7c..a1cf1db 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -138,9 +138,7 @@ int main(int argc, char *argv[]) { if (ctx.contains("tools")) { inputs.extra_context = json { - {"builtin_tools", { - json::array({"wolfram_alpha", "brave_search"}) - }}, + {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, }; } std::string actual; From 16664ae4691e210a1b5055948db089577905953a Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:41:56 +0000 Subject: [PATCH 21/59] Update tools prompt --- include/minja/chat-template.hpp | 55 ++++++++++++-------------- scripts/fetch_templates_and_goldens.py | 4 +- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index eed6946..5483f77 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -204,7 +204,6 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } -#if 0 try { if (!caps_.supports_tools) { const json user_msg { 
@@ -228,30 +227,33 @@ class chat_template { }, })}, }; - chat_template_inputs inputs; - inputs.messages = json::array({user_msg}); - inputs.add_generation_prompt = true; - auto prefix = apply(inputs); - - inputs.messages.push_back(tool_call_msg); - inputs.add_generation_prompt = false; - auto full = apply(inputs); + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } if (full.find(prefix) != 0) { if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - // } else { - // throw std::runtime_error("# prefix not found at start of prefix:\n" + prefix + "\n# vs full:\n" + full + "\n#"); } - } else { - + } + if (full.find(prefix) != 0) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); } tool_call_example_ = full.substr(prefix.size()); } } catch (const std::exception & e) { fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); } -#endif } const std::string & source() const { return source_; } @@ -265,21 +267,16 @@ class chat_template { { json actual_messages; + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto needs_polyfills = opts.apply_polyfills && (false || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments + || (has_tools && (false + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments + )) || caps_.requires_typed_content - // || !caps_.supports_system_role - // || (!tools.is_null() && (false - // || !caps_.supports_tools - // || !caps_.supports_tool_responses - // || !caps_.supports_tool_calls - // || caps_.requires_object_arguments - // )) - // || caps_.requires_typed_content ); if (needs_polyfills) { actual_messages = json::array(); @@ -313,9 +310,9 @@ class chat_template { json adjusted_messages; if (needs_tools_in_system) { adjusted_messages = add_system(inputs.messages, - "Available tools: " + inputs.tools.dump(2)); // "\n\n" - // "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2)); + // "\n\n" // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); } else { adjusted_messages = inputs.messages; @@ -459,7 +456,7 @@ class chat_template { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, + {"content", existing_system + "\n\n" + system_prompt}, }; } else { messages_with_system.insert(messages_with_system.begin(), json { diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 3251b10..a139d1e 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -66,7 +66,7 @@ def add_system(messages, system_prompt): existing_system = messages[0]["content"] messages[0] = { "role": "system", - "content": existing_system + "\n" + system_prompt, + "content": existing_system + "\n\n" + system_prompt, } else: messages.insert(0, { 
@@ -243,7 +243,7 @@ def apply(self, context): if self.needs_polyfills(context): if has_tools and not caps.supports_tools: - add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") + add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}") for message in context['messages']: if 'tool_calls' in message: From 01bc7f3fed148942eb36a317f09df1b5577fce08 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:01:24 +0000 Subject: [PATCH 22/59] expose all polyfills as options --- include/minja/chat-template.hpp | 73 +++++++++++++++++++++---------- tests/test-supported-template.cpp | 7 ++- 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 5483f77..46b8db6 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -46,6 +46,14 @@ struct chat_template_options { bool use_bos_token = true; bool use_eos_token = true; bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; }; class chat_template { @@ -268,21 +276,43 @@ class chat_template { json actual_messages; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + for (const auto & message : inputs.messages) { + if (!message["tool_calls"].is_null()) { + has_tool_calls = true; + } + if (message["role"] == "tool") { + has_tool_responses = true; + } + if (message["content"].is_string()) { + has_string_content = true; + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + auto needs_polyfills = opts.apply_polyfills && (false - || !caps_.supports_system_role - || (has_tools && (false - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments - )) - || caps_.requires_typed_content + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content ); + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { - if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { actual_messages.push_back({ {"role", msg.at("role")}, {"content", {{ @@ -305,15 +335,12 @@ class chat_template { pending_system.clear(); } 
}; - auto needs_tools_in_system = !inputs.tools.is_null() && inputs.tools.size() > 0 && !caps_.supports_tools; json adjusted_messages; - if (needs_tools_in_system) { + if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, - // "\n\n" - "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2)); - // "\n\n" - // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2) + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); } else { adjusted_messages = inputs.messages; } @@ -326,7 +353,7 @@ class chat_template { std::string role = message.at("role"); if (message.contains("tool_calls")) { - if (caps_.requires_object_arguments || !caps_.supports_tool_calls) { + if (polyfill_object_arguments || polyfill_tool_calls) { for (auto & tool_call : message.at("tool_calls")) { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); @@ -341,7 +368,7 @@ class chat_template { } } } - if (!caps_.supports_tool_calls) { + if (polyfill_tool_calls) { auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { @@ -368,7 +395,7 @@ class chat_template { message.erase("tool_calls"); } } - if (!caps_.supports_tool_responses && role == "tool") { + if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { {"tool_response", { @@ -385,7 +412,7 @@ class chat_template { message.erase("name"); } - if (!message["content"].is_null() && !caps_.supports_system_role) { + if (!message["content"].is_null() && polyfill_system_role) { std::string content = message.at("content"); if (role == "system") { if (!pending_system.empty()) pending_system += "\n"; @@ -404,9 +431,7 @@ class chat_template { } add_message(message); } - if (!caps_.supports_system_role) { - flush_sys(); - } + flush_sys(); } else { actual_messages = inputs.messages; } @@ -426,13 +451,13 @@ class chat_template { context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { args.expectArgs("strftime_now", {1, 1}, {0, 0}); auto format = args.args[0].get(); - + auto time = std::chrono::system_clock::to_time_t(now); auto local_time = *std::localtime(&time); std::ostringstream ss; ss << std::put_time(&local_time, format.c_str()); return ss.str(); - })); + })); } if (!inputs.tools.is_null()) { context->set("tools", minja::Value(inputs.tools)); diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index a1cf1db..96d7cfa 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -141,9 +141,14 @@ int main(int argc, char *argv[]) { {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, }; } + + minja::chat_template_options opts; + // TODO: implement logic for examples in python + opts.polyfill_tool_call_examples = false; + std::string actual; try { - actual = tmpl.apply(inputs); + actual = tmpl.apply(inputs, opts); } catch (const std::exception &e) { std::cerr << "Error applying template: " << e.what() << std::endl; return 1; From a1608667d1316f0726dbae1807ae40a260d1a9e7 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:30:22 +0000 Subject: [PATCH 23/59] Fully align tool call examples backfill test python logic w/ c++ --- include/minja/chat-template.hpp | 9 +-- 
scripts/fetch_templates_and_goldens.py | 92 +++++++++++++++++++------- tests/test-supported-template.cpp | 6 +- 3 files changed, 74 insertions(+), 33 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 46b8db6..dfd46d7 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -218,6 +218,9 @@ class chat_template { {"role", "user"}, {"content", "Hey"}, }; + const json args { + {"arg1", "some_value"}, + }; const json tool_call_msg { {"role", "assistant"}, {"content", nullptr}, @@ -228,9 +231,7 @@ class chat_template { {"type", "function"}, {"function", { {"name", "tool_name"}, - {"arguments", (json { - {"arg1", "some_value"}, - }).dump()}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, }}, }, })}, @@ -339,7 +340,7 @@ class chat_template { json adjusted_messages; if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, - "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2) + + "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); } else { adjusted_messages = inputs.messages; diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index a139d1e..66238bc 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -86,7 +86,7 @@ class TemplateCaps: requires_object_arguments: bool = False requires_non_null_content: bool = False requires_typed_content: bool = False - + def to_json(self): return json.dumps({ "supports_tools": self.supports_tools, @@ -108,7 +108,7 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext "bos_token": "<|startoftext|>", "eos_token": "<|endoftext|>", } - + try: out = self.template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) # print(out, file=sys.stderr) @@ -117,7 +117,7 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - def __init__(self, template, env=None): + def __init__(self, template, known_eos_tokens, env=None): if not env: env = jinja2.Environment( trim_blocks=True, @@ -128,21 +128,21 @@ def __init__(self, template, env=None): self.template = env.from_string(template) caps = TemplateCaps() - + user_needle = "" sys_needle = "" dummy_str_user_msg = {"role": "user", "content": user_needle } dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} - + caps.requires_typed_content = \ (user_needle not in self.try_raw_render([dummy_str_user_msg])) \ and (user_needle in self.try_raw_render([dummy_typed_user_msg])) dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg - + needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} - + caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg]) - + out = self.try_raw_render([dummy_user_msg], tools=[{ "name": "some_tool", "type": "function", @@ -162,7 +162,7 @@ def __init__(self, template, env=None): 
}, }]) caps.supports_tools = "some_tool" in out - + def make_tool_calls_msg(tool_calls, content=None): return { "role": "assistant", @@ -178,7 +178,7 @@ def make_tool_call(tool_name, arguments): "name": tool_name, } } - + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} out = self.try_raw_render([ @@ -191,10 +191,10 @@ def make_tool_call(tool_name, arguments): make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out - + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - + caps.requires_non_null_content = \ (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) @@ -208,7 +208,7 @@ def make_tool_call(tool_name, arguments): make_tool_calls_msg([tc1, tc2]), ]) caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - + out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([tc1]), @@ -221,9 +221,42 @@ def make_tool_call(tool_name, arguments): ]) caps.supports_tool_responses = "Some response!" in out caps.supports_tool_call_id = "call_911_" in out - + + self.tool_call_example = None + try: + if not caps.supports_tools: + user_msg = {"role": "user", "content": "Hey"} + args = {"arg1": "some_value"} + tool_call_msg = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1___", + "type": "function", + "function": { + "name": "tool_name", + "arguments": args if caps.requires_object_arguments else json.dumps(args), + }, + }, + ], + } + prefix = self.try_raw_render([user_msg], add_generation_prompt=True) + full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False) + if not full.startswith(prefix): + for known_eos_token in known_eos_tokens: + prefix = prefix.rstrip() + if prefix.endswith(known_eos_token): + prefix = prefix[:-len(known_eos_token)] + break + if not full.startswith(prefix): + print("Failed to infer a tool call example (possible template bug)", file=sys.stderr) + self.tool_call_example = full[len(prefix):] + except Exception as e: + print(f"Failed to generate tool call example: {e}", file=sys.stderr) + self.original_caps = caps - + def needs_polyfills(self, context): has_tools = context.get('tools') is not None caps = self.original_caps @@ -237,13 +270,15 @@ def needs_polyfills(self, context): or caps.requires_typed_content def apply(self, context): - + caps = self.original_caps has_tools = 'tools' in context if self.needs_polyfills(context): if has_tools and not caps.supports_tools: - add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}") + add_system(context['messages'], + f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" + + ("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else "")) for message in context['messages']: if 'tool_calls' in message: @@ -299,7 +334,7 @@ def apply(self, context): return f"ERROR: {e2}" - + async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): if '{% generation %}' in template_src: @@ -315,21 +350,30 @@ async def 
handle_chat_template(output_folder, model_id, variant, template_src, c async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - template = chat_template(template_src) + known_eos_tokens = [ + "<|END_OF_TURN_TOKEN|>", + "", + "", + "<|im_end|>", + "<|eom_id|>", + "<|eot_id|>", + "<|end▁of▁sentence|>", + ] + + template = chat_template(template_src, known_eos_tokens) template.env.filters['safe'] = lambda x: x template.env.filters['tojson'] = tojson template.env.globals['raise_exception'] = raise_exception template.env.globals['strftime_now'] = strftime_now - caps = template.original_caps - + if not context_files: print(f"{template_file} {caps_file} n/a {template_file}") return async with aiofiles.open(caps_file, 'w') as f: await f.write(caps.to_json()) - + for context_file in context_files: context_name = os.path.basename(context_file).replace(".json", "") async with aiofiles.open(context_file, 'r') as f: @@ -338,7 +382,7 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c if not caps.supports_tool_calls and context.get('tools') is not None: print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue - + needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): continue @@ -362,7 +406,7 @@ async def async_hf_download(repo_id: str, filename: str) -> str: async def process_model(output_folder: str, model_id: str, context_files: list): try: config_str = await async_hf_download(model_id, "tokenizer_config.json") - + try: config = json.loads(config_str) except json.JSONDecodeError: diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 96d7cfa..302ebbd 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -142,13 +142,9 @@ int main(int argc, char *argv[]) { }; } - minja::chat_template_options opts; - // TODO: implement logic for examples in python - opts.polyfill_tool_call_examples = false; - std::string actual; try { - actual = tmpl.apply(inputs, opts); + actual = tmpl.apply(inputs); } catch (const std::exception &e) { std::cerr << "Error applying template: " << e.what() << std::endl; return 1; From e5afc512dd8fd87f19ab204ac7f1b8f7cfab21bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:44:10 +0000 Subject: [PATCH 24/59] Add / deprecate old chat_template::apply overload --- include/minja/chat-template.hpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index dfd46d7..2c3d96c 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -270,6 +270,28 @@ class chat_template { const std::string & eos_token() const { return eos_token_; } const chat_template_caps & original_caps() const { return caps_; } + // Deprecated, please use the form with chat_template_inputs and chat_template_options + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + 
inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + std::string apply( const chat_template_inputs & inputs, const chat_template_options & opts = chat_template_options()) const From f9969534fce007440e0b7ccf58673091a79cc657 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:58:35 +0000 Subject: [PATCH 25/59] fix crash --- include/minja/chat-template.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 2c3d96c..69ee4e8 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -303,13 +303,13 @@ class chat_template { auto has_tool_responses = false; auto has_string_content = false; for (const auto & message : inputs.messages) { - if (!message["tool_calls"].is_null()) { + if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { has_tool_calls = true; } - if (message["role"] == "tool") { + if (message.contains("role") && message["role"] == "tool") { has_tool_responses = true; } - if (message["content"].is_string()) { + if (message.contains("content") && message["content"].is_string()) { has_string_content = true; } } From 7e7730e22c5fa6cc7f9d9eb2da2a43e6fb264e45 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 00:40:41 +0000 Subject: [PATCH 26/59] mute another win arm64 deprecation --- tests/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0d1c40e..6a34dd4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,6 +41,9 @@ set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYT add_executable(test-supported-template test-supported-template.cpp) target_compile_features(test-supported-template PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-supported-template PUBLIC _CRT_SECURE_NO_WARNINGS) +endif() target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) set(MODEL_IDS From e083fc1e2c937f42839f1bde92a5924b014253b4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 01:07:04 +0000 Subject: [PATCH 27/59] mute another win arm64 deprecation --- examples/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 750bdea..7d75913 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -12,5 +12,8 @@ foreach(example add_executable(${example} ${example}.cpp) target_compile_features(${example} PUBLIC cxx_std_17) target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json) + if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(${example} PUBLIC _CRT_SECURE_NO_WARNINGS) + endif() endforeach() From 1c59c3a0ce6d2f5db77ce0a24afd7058db57d2d6 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 05:02:52 +0000 Subject: [PATCH 28/59] Fix eos / bos when disabled in options (needs empty) --- include/minja/chat-template.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 69ee4e8..0e88fb3 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -463,12 +463,8 @@ class chat_template { {"messages", 
actual_messages}, {"add_generation_prompt", inputs.add_generation_prompt}, })); - if (opts.use_bos_token) { - context->set("bos_token", bos_token_); - } - if (opts.use_eos_token) { - context->set("eos_token", eos_token_); - } + context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); + context->set("eos_token", opts.use_eos_token ? eos_token_ : ""); if (opts.define_strftime_now) { auto now = inputs.now; context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { From 611b631b529ed8399be37f3fd4cb5e31bf08104e Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 09:51:43 +0000 Subject: [PATCH 29/59] Add test-polyfills --- tests/CMakeLists.txt | 19 ++- tests/test-polyfills.cpp | 354 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 372 insertions(+), 1 deletion(-) create mode 100644 tests/test-polyfills.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a34dd4..89e3dc8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,8 +17,25 @@ target_link_libraries(test-syntax PRIVATE gtest_main gmock ) + +add_executable(test-polyfills test-polyfills.cpp) +target_compile_features(test-polyfills PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-polyfills PUBLIC _CRT_SECURE_NO_WARNINGS) + target_compile_options(gtest PRIVATE -Wno-language-extension-token) +endif() +target_link_libraries(test-polyfills PRIVATE + nlohmann_json::nlohmann_json + gtest_main + gmock +) +if (NOT CMAKE_CROSSCOMPILING) + gtest_discover_tests(test-syntax) +endif() + if (NOT CMAKE_CROSSCOMPILING) gtest_discover_tests(test-syntax) + gtest_discover_tests(test-polyfills) endif() add_executable(test-capabilities test-capabilities.cpp) @@ -54,7 +71,7 @@ set(MODEL_IDS # minja implementation on the same template and context, and compare the output with the golden. # # For Gated models, you'll need to run `huggingface-cli login` (and be granted access) to download their template. - + abacusai/Fewshot-Metamath-OrcaVicuna-Mistral bofenghuang/vigogne-2-70b-chat CohereForAI/c4ai-command-r-plus # Gated diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp new file mode 100644 index 0000000..d0fcbfc --- /dev/null +++ b/tests/test-polyfills.cpp @@ -0,0 +1,354 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. 
+*/ +// SPDX-License-Identifier: MIT +#include "minja.hpp" +#include +#include + +#include +#include +#include +#include "chat-template.hpp" + +using namespace minja; + +#define TEMPLATE_CHATML \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_CHATML_NO_SYSTEM \ + "{%- for message in messages -%}\n" \ + " {%- if message.role == 'system' -%}\n" \ + " {{- raise_exception('System role not supported') -}}\n" \ + " {%- endif -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_DUMMY \ + "{%- for tool in tools -%}\n" \ + " {{- 'tool: ' + (tool | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- for message in messages -%}\n" \ + " {{- 'message: ' + (message | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- 'message: ' -}}\n" \ + "{%- endif -%}" + + +const json message_user_text { + { "role", "user" }, + { "content", "I need help" }, +}; +const json message_assistant_text { + { "role", "assistant" }, + { "content", "Hello, world!" }, +}; +const json message_system { + { "role", "system" }, + { "content", "I am The System!" }, +}; +const json tool_calls = json::array({{ + { "type", "function" }, + { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, +}}); + +const json message_assistant_call { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, +}; +const json message_assistant_call_id { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "123456789"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const json message_assistant_call_idx { + { "role", "assistant"}, + { "content", {}}, + { "tool_plan", "I'm not so sure"}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "0"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const json message_tool { + { "role", "tool"}, + { "content", { + {"result", 123}, + }}, +}; + +const auto special_function_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." 
+ } + }, + "required": ["arg1"] + } + } +})"); + +auto ThrowsWithSubstr = [](const std::string & expected_substr) { + return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); +}; + +static chat_template_options options_no_polyfills() { + chat_template_options opts; + opts.apply_polyfills = false; + opts.polyfill_system_role = false; + opts.polyfill_tools = false; + opts.polyfill_tool_call_examples = false; + opts.polyfill_tool_calls = false; + opts.polyfill_tool_responses = false; + opts.polyfill_object_arguments = false; + opts.polyfill_typed_content = false; + return opts; +}; + +TEST(PolyfillTest, NoPolyFill) { + chat_template tmpl(TEMPLATE_CHATML, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.add_generation_prompt = false; + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.messages = json::array({message_user_text, message_assistant_text}); + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "Hello, world!<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); +} + +TEST(PolyfillTest, SystemRoleSupported) { + chat_template chatml(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template dummy(TEMPLATE_DUMMY, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, message_user_text}); + + EXPECT_EQ( + "<|im_start|>system\n" + "I am The System!<|im_end|>\n" + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + chatml.apply(inputs)); + EXPECT_EQ( + "message: {\n" + " \"role\": \"system\",\n" + " \"content\": \"I am The System!\"\n" + "}\n" + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: ", + dummy.apply(inputs)); +} + +TEST(PolyfillTest, SystemRolePolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, message_user_text}); + + EXPECT_THAT( + [&]() { tmpl.apply(inputs, options_no_polyfills()); }, + ThrowsWithSubstr("System role not supported")); + + EXPECT_EQ( + "<|im_start|>user\n" + "I am The System!\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " }\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "{\n" 
+ " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolsPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + inputs.tools = json::array({special_function_tool}); + + EXPECT_EQ( + "<|im_start|>user\n" + "You can call any of the following tools to satisfy the user's requests: [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]\n" + "\n" + "Example tool call syntax:\n" + "\n" + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"tool_name\",\n" + " \"arguments\": {\n" + " \"arg1\": \"some_value\"\n" + " },\n" + " \"id\": \"call_1___\"\n" + " }\n" + " ]\n" + "}<|im_end|>\n" + "\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>user\n{\n" + " \"tool_response\": {\n" + " \"content\": {\n" + " \"result\": 123\n" + " }\n" + " }\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} \ No newline at end of file From a12dc8a9d09fbd1e14c897d5812f8e96f5b93a1d Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 09:56:27 +0000 Subject: [PATCH 30/59] Update test-polyfills.cpp --- tests/test-polyfills.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index d0fcbfc..0b52836 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -328,13 +328,30 @@ TEST(PolyfillTest, ToolsPolyfill) { " \"id\": \"call_1___\"\n" " }\n" " ]\n" - "}<|im_end|>\n" + "}<|im_end|>\n" // TODO: fix this "\n" "I need help<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); } +TEST(PolyfillTest, ToolSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"tool\",\n" + " \"content\": {\n" + " \"result\": 123\n" + " }\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + TEST(PolyfillTest, ToolPolyfill) { chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); From b9c335e493a5ea082770f2aaaa350f9205e802d3 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 10:20:18 +0000 Subject: [PATCH 31/59] Fix tool call example optional final eos elision --- include/minja/chat-template.hpp | 11 +++++------ scripts/fetch_templates_and_goldens.py | 2 +- tests/test-polyfills.cpp | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 0e88fb3..2efb69a 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -249,11 +249,10 @@ class chat_template { 
inputs.add_generation_prompt = false; full = apply(inputs); } - - if (full.find(prefix) != 0) { - if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - } + auto eos_pos_last = full.rfind(eos_token_); + if (eos_pos_last == prefix.size() - eos_token_.size() || + (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { + full = full.substr(0, eos_pos_last); } if (full.find(prefix) != 0) { fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); @@ -363,7 +362,7 @@ class chat_template { if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + - (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); } else { adjusted_messages = inputs.messages; } diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 66238bc..5637b3e 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -278,7 +278,7 @@ def apply(self, context): if has_tools and not caps.supports_tools: add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" + - ("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else "")) + ("\n\nExample tool call syntax:\n\n" + self.tool_call_example + "\n\n" if self.tool_call_example is not None else "")) for message in context['messages']: if 'tool_calls' in message: diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index 0b52836..9f9d1f7 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -158,7 +158,7 @@ static chat_template_options options_no_polyfills() { }; TEST(PolyfillTest, NoPolyFill) { - chat_template tmpl(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text}); @@ -185,7 +185,7 @@ TEST(PolyfillTest, NoPolyFill) { } TEST(PolyfillTest, SystemRoleSupported) { - chat_template chatml(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template chatml(TEMPLATE_CHATML, "", ""); chat_template dummy(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); @@ -212,7 +212,7 @@ TEST(PolyfillTest, SystemRoleSupported) { } TEST(PolyfillTest, SystemRolePolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_system, message_user_text}); @@ -230,7 +230,7 @@ TEST(PolyfillTest, SystemRolePolyfill) { } TEST(PolyfillTest, ToolCallSupported) { - chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text, message_assistant_call_id}); @@ -261,7 +261,7 @@ TEST(PolyfillTest, ToolCallSupported) { } TEST(PolyfillTest, ToolCallPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", ""); auto inputs = chat_template_inputs(); 
inputs.messages = json::array({message_user_text, message_assistant_call_id}); @@ -286,14 +286,14 @@ TEST(PolyfillTest, ToolCallPolyfill) { } TEST(PolyfillTest, ToolsPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", "<|im_end|>"); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text}); inputs.tools = json::array({special_function_tool}); EXPECT_EQ( - "<|im_start|>user\n" + "<|im_start|>system\n" "You can call any of the following tools to satisfy the user's requests: [\n" " {\n" " \"type\": \"function\",\n" @@ -328,15 +328,15 @@ TEST(PolyfillTest, ToolsPolyfill) { " \"id\": \"call_1___\"\n" " }\n" " ]\n" - "}<|im_end|>\n" // TODO: fix this - "\n" + "}\n\n<|im_end|>\n" + "<|im_start|>user\n" "I need help<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); } TEST(PolyfillTest, ToolSupported) { - chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_tool}); @@ -353,7 +353,7 @@ TEST(PolyfillTest, ToolSupported) { } TEST(PolyfillTest, ToolPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_tool}); From 3f30c0fe4b41b88d75c1820f3a4dca946b20d13a Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Tue, 4 Feb 2025 16:10:46 -0500 Subject: [PATCH 32/59] simplify whitespace matching in regexes --- include/minja/minja.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index c304b5c..bf21fe7 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1792,7 +1792,7 @@ class Parser { auto left = parseStringConcat(); if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); - static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); static std::regex not_tok(R"(not\b)"); std::string op_str; while (!(op_str = consumeToken(compare_tok)).empty()) { @@ -2171,7 +2171,7 @@ class Parser { using TemplateTokenIterator = TemplateTokenVector::const_iterator; std::vector parseVarNames() { - static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); std::vector group; if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); @@ -2194,13 +2194,13 @@ class Parser { } TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})"); + static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); - static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); - static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); - static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); + static std::regex 
block_close_regex(R"(\s*([-~])?%\})"); TemplateTokenVector tokens; std::vector group; @@ -2284,7 +2284,7 @@ class Parser { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "set") { - static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); std::string ns; std::vector var_names; @@ -2400,7 +2400,7 @@ class Parser { auto text = text_token->text; if (post_space == SpaceHandling::Strip) { - static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + static std::regex trailing_space_regex(R"(\s+$)"); text = std::regex_replace(text, trailing_space_regex, ""); } else if (options.lstrip_blocks && it != end) { auto i = text.size(); @@ -2410,7 +2410,7 @@ class Parser { } } if (pre_space == SpaceHandling::Strip) { - static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { if (text.length() > 0 && text[0] == '\n') { From 96ebff4f6659ee35e879638e88b64ab1b31dbbe0 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Tue, 4 Feb 2025 17:02:21 -0500 Subject: [PATCH 33/59] tokenize: disallow zero-length matches --- include/minja/minja.hpp | 5 +++++ tests/test-syntax.cpp | 1 + 2 files changed, 6 insertions(+) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index bf21fe7..c726592 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2336,6 +2336,11 @@ class Parser { throw std::runtime_error("Unexpected block: " + keyword); } } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + throw std::runtime_error("Internal error: Expected a comment"); + throw std::runtime_error("Missing end of comment tag"); + } auto text_end = it + match.position(); text = std::string(it, text_end); it = text_end; diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index b6f6d57..85bf222 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -522,6 +522,7 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% if 1 %}{% else %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter")); + EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag")); } EXPECT_EQ( From e259cda35cb88d21d89911bdfa853ba48428600b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 8 Feb 2025 19:21:33 +0000 Subject: [PATCH 34/59] Add str.capitalize (support BEE-spoke-data/tFINE-900m-instruct-orpo) (#51) * add string.capitalize * test BEE-spoke-data/tFINE-900m-instruct-orpo --- include/minja/minja.hpp | 10 ++++++++++ tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index c726592..c58dd66 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) { return s.substr(start, end - start + 1); } +static std::string capitalize(const std::string & s) { + if (s.empty()) return s; + auto result = s; + result[0] = std::toupper(result[0]); + return result; 
+} + static std::string html_escape(const std::string & s) { std::string result; result.reserve(s.size()); @@ -1462,6 +1469,9 @@ class MethodCallExpr : public Expression { if (method->get_name() == "strip") { vargs.expectArgs("strip method", {0, 0}, {0, 0}); return Value(strip(str)); + } else if (method->get_name() == "capitalize") { + vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); + return Value(capitalize(str)); } else if (method->get_name() == "endswith") { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 89e3dc8..550515d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -78,6 +78,7 @@ set(MODEL_IDS databricks/dbrx-instruct # Gated google/gemma-2-2b-it # Gated google/gemma-7b-it # Gated + BEE-spoke-data/tFINE-900m-instruct-orpo MiniMaxAI/MiniMax-Text-01 indischepartij/MiniCPM-3B-OpenHermes-2.5-v2 mattshumer/Reflection-Llama-3.1-70B diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 85bf222..b4f4673 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -74,6 +74,10 @@ TEST(SyntaxTest, SimpleCases) { return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); }; + EXPECT_EQ( + "Ok", + render("{{ 'ok'.capitalize() }}", {}, {})); + EXPECT_EQ( "ok", render("{# Hey\nHo #}{#- Multiline...\nComments! -#}{{ 'ok' }}{# yo #}", {}, {})); From 7eb5202c5faecc7bbc80f1d4b4cf750eb3485645 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sun, 9 Feb 2025 17:02:41 +0000 Subject: [PATCH 35/59] Fix deepseek r1 tool call example polyfill (template newly adds trailing ) (#52) * Fix deepseek r1 tool call example polyfill (their template newly adds trailing ) * test tool outputs for common templates * tests: align extra context in c++ w/ python + remove python tojson override --- include/minja/chat-template.hpp | 19 ++- scripts/fetch_templates_and_goldens.py | 67 ++++---- tests/CMakeLists.txt | 8 +- tests/contexts/simple.json | 3 +- tests/contexts/system.json | 3 +- tests/contexts/tool_use.json | 57 +++---- tests/test-polyfills.cpp | 214 ++++++++++++++++++++++++- tests/test-supported-template.cpp | 14 +- 8 files changed, 310 insertions(+), 75 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 2efb69a..882ba41 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -254,10 +254,25 @@ class chat_template { (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { full = full.substr(0, eos_pos_last); } - if (full.find(prefix) != 0) { + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; + } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. 
+ continue; + } + common_prefix_length = i + 1; + } + auto example = full.substr(common_prefix_length); + if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } else { + tool_call_example_ = example; } - tool_call_example_ = full.substr(prefix.size()); } } catch (const std::exception & e) { fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 5637b3e..fd022ba 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -43,9 +43,6 @@ def raise_exception(message: str): raise ValueError(message) -def tojson(eval_ctx, value, indent=None): - return json.dumps(value, indent=indent) - TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26') @@ -114,16 +111,22 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext # print(out, file=sys.stderr) return out except BaseException as e: - # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) + # print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - def __init__(self, template, known_eos_tokens, env=None): + def __init__(self, template, env=None, filters=None, global_functions=None): if not env: env = jinja2.Environment( trim_blocks=True, lstrip_blocks=True, extensions=[jinja2.ext.loopcontrols] ) + if filters: + for name, func in filters.items(): + env.filters[name] = func + if global_functions: + for name, func in global_functions.items(): + env.globals[name] = func self.env = env self.template = env.from_string(template) @@ -243,15 +246,24 @@ def make_tool_call(tool_name, arguments): } prefix = self.try_raw_render([user_msg], add_generation_prompt=True) full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False) - if not full.startswith(prefix): - for known_eos_token in known_eos_tokens: - prefix = prefix.rstrip() - if prefix.endswith(known_eos_token): - prefix = prefix[:-len(known_eos_token)] - break - if not full.startswith(prefix): + + common_prefix_length = 0 + for i in range(min(len(prefix), len(full))): + if prefix[i] != full[i]: + break + if prefix[i] == '<': + # DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + # but it removes thinking tags for past messages. + # The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. 
+ continue + common_prefix_length = i + 1 + + example = full[common_prefix_length:] + if "tool_name" not in example and "some_value" not in example: print("Failed to infer a tool call example (possible template bug)", file=sys.stderr) - self.tool_call_example = full[len(prefix):] + else: + self.tool_call_example = example + except Exception as e: print(f"Failed to generate tool call example: {e}", file=sys.stderr) @@ -321,7 +333,11 @@ def apply(self, context): message['content'] = [{"type": "text", "text": message['content']}] try: - return self.template.render(**context) + out = self.template.render(**context) + out = out.replace("\\u0027", "'") + out = out.replace('"', '"') + out = out.replace(''', "'") + return out except Exception as e1: for message in context['messages']: if message.get("content") is None: @@ -350,21 +366,14 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - known_eos_tokens = [ - "<|END_OF_TURN_TOKEN|>", - "", - "", - "<|im_end|>", - "<|eom_id|>", - "<|eot_id|>", - "<|end▁of▁sentence|>", - ] - - template = chat_template(template_src, known_eos_tokens) - template.env.filters['safe'] = lambda x: x - template.env.filters['tojson'] = tojson - template.env.globals['raise_exception'] = raise_exception - template.env.globals['strftime_now'] = strftime_now + template = chat_template(template_src, + filters={ + 'safe': lambda x: x, + }, + global_functions={ + 'raise_exception': raise_exception, + 'strftime_now': strftime_now, + }) caps = template.original_caps if not context_files: diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 550515d..47b0900 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,11 +31,8 @@ target_link_libraries(test-polyfills PRIVATE ) if (NOT CMAKE_CROSSCOMPILING) gtest_discover_tests(test-syntax) -endif() - -if (NOT CMAKE_CROSSCOMPILING) - gtest_discover_tests(test-syntax) - gtest_discover_tests(test-polyfills) + add_test(NAME test-polyfills COMMAND test-polyfills) + set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endif() add_executable(test-capabilities test-capabilities.cpp) @@ -82,6 +79,7 @@ set(MODEL_IDS MiniMaxAI/MiniMax-Text-01 indischepartij/MiniCPM-3B-OpenHermes-2.5-v2 mattshumer/Reflection-Llama-3.1-70B + meetkai/functionary-medium-v3.1 meetkai/functionary-medium-v3.2 meta-llama/Llama-3.1-8B-Instruct # Gated meta-llama/Llama-3.2-3B-Instruct # Gated diff --git a/tests/contexts/simple.json b/tests/contexts/simple.json index 560f92f..5e89f22 100644 --- a/tests/contexts/simple.json +++ b/tests/contexts/simple.json @@ -11,5 +11,6 @@ ], "add_generation_prompt": true, "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>" + "eos_token": "<|endoftext|>", + "tools_in_user_message": false } diff --git a/tests/contexts/system.json b/tests/contexts/system.json index 4d72972..7cbc5c2 100644 --- a/tests/contexts/system.json +++ b/tests/contexts/system.json @@ -15,5 +15,6 @@ ], "add_generation_prompt": true, "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>" + "eos_token": "<|endoftext|>", + "tools_in_user_message": false } diff --git a/tests/contexts/tool_use.json b/tests/contexts/tool_use.json index 4920d19..cca70cb 100644 --- a/tests/contexts/tool_use.json +++ b/tests/contexts/tool_use.json @@ -88,6 +88,7 @@ "add_generation_prompt": true, "bos_token": "<|startoftext|>", "eos_token": "<|endoftext|>", + "tools_in_user_message": false, "builtin_tools": [ 
"wolfram_alpha", "brave_search" @@ -96,72 +97,72 @@ "todays_date": "2024-09-03", "tools": [ { - "type": "function", "function": { - "name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "name": "ipython", "parameters": { - "type": "object", "properties": { "code": { - "type": "string", - "description": "The code to run in the ipython interpreter." + "description": "The code to run in the ipython interpreter.", + "type": "string" } }, - "required": ["code"] + "required": ["code"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "brave_search", "description": "Executes a web search with Brave.", + "name": "brave_search", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to search for." + "description": "The query to search for.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", + "name": "wolfram_alpha", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to execute." + "description": "The query to execute.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "test", "description": "Runs a test.", + "name": "test", "parameters": { - "type": "object", "properties": { "condition": { - "type": "boolean", - "description": "The condition to test." + "description": "The condition to test.", + "type": "boolean" } }, - "required": ["condition"] + "required": ["condition"], + "type": "object" } - } + }, + "type": "function" } ] } \ No newline at end of file diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index 9f9d1f7..d1c598b 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -17,6 +17,22 @@ using namespace minja; +static std::string read_file(const std::string &path) +{ + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) + { + throw std::runtime_error("Failed to open file: " + path); + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + #define TEMPLATE_CHATML \ "{%- for message in messages -%}\n" \ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ @@ -120,6 +136,7 @@ const json message_tool { { "content", { {"result", 123}, }}, + { "tool_call_id", "123456789"}, }; const auto special_function_tool = json::parse(R"({ @@ -346,7 +363,8 @@ TEST(PolyfillTest, ToolSupported) { " \"role\": \"tool\",\n" " \"content\": {\n" " \"result\": 123\n" - " }\n" + " },\n" + " \"tool_call_id\": \"123456789\"\n" "}\n" "message: ", tmpl.apply(inputs)); @@ -363,9 +381,199 @@ TEST(PolyfillTest, ToolPolyfill) { " \"tool_response\": {\n" " \"content\": {\n" " \"result\": 123\n" - " }\n" + " },\n" + " \"tool_call_id\": \"123456789\"\n" " }\n" "}<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); -} \ No newline at end of file +} + +#ifndef _WIN32 +TEST(ToolTest, DeepSeekR1) { + chat_template tmpl(read_file("tests/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + 
inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|tool▁outputs▁begin|><|tool▁output▁begin|>{'result': 123}<|tool▁output▁end|><|tool▁outputs▁end|>", + tmpl.apply(inputs)); +} + +TEST(ToolTest, CommandR7b) { + chat_template tmpl(read_file("tests/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n" + "You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n" + "\n" + "Your information cutoff date is June 2024.\n" + "\n" + "You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n" + "# Default Preamble\n" + "The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n" + "- Your name is Command.\n" + "- You are a large language model built by Cohere.\n" + "- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n" + "- If the input is ambiguous, ask clarifying follow-up questions.\n" + "- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n" + "- Use LaTeX to generate mathematical notation for complex equations.\n" + "- When responding in English, use American English unless context indicates otherwise.\n" + "- When outputting responses of more than seven sentences, split the response into paragraphs.\n" + "- Prefer the active voice.\n" + "- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. 
Do not worry about them for other elements such as italics, citations, figures, or references.\n" + "- Use gender-neutral pronouns for unspecified persons.\n" + "- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n" + "- Use the third person when asked to write a summary.\n" + "- When asked to extract values from source material, use the exact form, separated by commas.\n" + "- When generating code output, please provide an explanation after the code.\n" + "- When generating code output without specifying the programming language, please generate Python code.\n" + "- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n" + " {\n" + " \"tool_call_id\": \"\",\n" + " \"results\": {\n" + " \"0\": {\"result\": 123}\n" + " },\n" + " \"is_error\": null\n" + " }\n" + "]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + tmpl.apply(inputs)); +} +#endif // NOT _WIN32 + +TEST(ToolTest, MistralNemo) { + chat_template tmpl(read_file("tests/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "[TOOL_RESULTS]{\"content\": {'result': 123}, \"call_id\": \"123456789\"}[/TOOL_RESULTS]", + tmpl.apply(inputs)); +} + +TEST(ToolTest, NousResearchHermes3) { + chat_template tmpl(read_file("tests/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>system\n" + "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n" + "For each function call return a json object with function name and arguments within XML tags as follows:\n" + "\n" + "{\"name\": , \"arguments\": }\n" + "<|im_end|>\n" + "\n" + "{'result': 123}\n" + "<|im_end|><|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, NousResearchHermes2) { + chat_template tmpl(read_file("tests/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>system\n" + "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n" + "For each function call return a json object with function name and arguments within XML tags as follows:\n" + "\n" + "{\"name\": , \"arguments\": }\n" + "<|im_end|>\n" + "\n" + "{'result': 123}\n" + "<|im_end|><|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, Llama3_3) { + chat_template tmpl(read_file("tests/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "Cutting Knowledge Date: December 2023\n" + "Today Date: 26 Jul 2024\n" + "\n" + "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n" + "\n" + "{\"result\": 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, MeetkaiFunctionary3_1) { + chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.1.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "\n" + "Cutting Knowledge Date: December 2023\n" + "\n" + "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n" + "\n" + "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, MeetkaiFunctionary3_2) { + chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.2.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "You are capable of executing available function(s) if required.\n" + "Only execute function(s) when absolutely necessary.\n" + "Ask for the required input to:recipient==all\n" + "Use JSON for function arguments.\n" + "Respond in this format:\n" + ">>>${recipient}\n" + "${content}\n" + "Available functions:\n" + "// Supported function definitions that should be called when necessary.\n" + "namespace functions {\n" + "\n" + "} // namespace functions<|eot_id|><|start_header_id|>tool<|end_header_id|>\n" + "\n" + "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n" + ">>>", + tmpl.apply(inputs)); +} + +/* +https://github.com/google/minja/issues/7 +TEST(ToolTest, FirefunctionV2) { + chat_template tmpl(read_file("tests/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>tool\n" + "{\n" + " \"result\": 123\n" + "}\n" + "<|im_end|>", + tmpl.apply(inputs)); +} +*/ diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 302ebbd..965375f 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -128,19 +128,21 @@ int main(int argc, char *argv[]) { struct minja::chat_template_inputs inputs; inputs.messages = ctx.at("messages"); - inputs.tools = ctx.contains("tools") ? 
ctx.at("tools") : json(); + ctx.erase("messages"); + + if (ctx.contains("tools")) { + inputs.tools = ctx.at("tools"); + ctx.erase("tools"); + } inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + ctx.erase("add_generation_prompt"); std::istringstream ss(TEST_DATE); std::tm tm = {}; ss >> std::get_time(&tm, "%Y-%m-%d"); inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm)); - if (ctx.contains("tools")) { - inputs.extra_context = json { - {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, - }; - } + inputs.extra_context = ctx; std::string actual; try { From ff6f9a0fac798cc172883507189a5e152b00f0bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sun, 9 Feb 2025 17:27:31 +0000 Subject: [PATCH 36/59] Refactor test infra (#53) * Refactor async code of test goldens fetcher * ci: checkout single branch --- .github/workflows/build.yml | 3 +- scripts/fetch_templates_and_goldens.py | 65 ++++++++++++++------------ tests/CMakeLists.txt | 4 +- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d73d9b9..5f7fbdb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -52,7 +52,8 @@ jobs: - name: Clone uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 1 + single-branch: true - name: ccache uses: hendrikmuhs/ccache-action@v1.2.11 diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index fd022ba..6e65099 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -281,7 +281,9 @@ def needs_polyfills(self, context): )) \ or caps.requires_typed_content - def apply(self, context): + def apply(self, context: dict): + assert isinstance(context, dict) + context = json.loads(json.dumps(context)) caps = self.original_caps has_tools = 'tools' in context @@ -349,10 +351,14 @@ def apply(self, context): logger.info(f" ERROR: {e2} (after first error: {e1})") return f"ERROR: {e2}" +@dataclass +class Context: + name: str + file: str + bindings: dict - -async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): +async def handle_chat_template(output_folder, model_id, variant, template_src, contexts: list[Context]): if '{% generation %}' in template_src: print('Removing {% generation %} blocks from template', file=sys.stderr) template_src = template_src.replace('{% generation %}', '').replace('{% endgeneration %}', '') @@ -376,33 +382,32 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c }) caps = template.original_caps - if not context_files: + if not contexts: print(f"{template_file} {caps_file} n/a {template_file}") return async with aiofiles.open(caps_file, 'w') as f: await f.write(caps.to_json()) - for context_file in context_files: - context_name = os.path.basename(context_file).replace(".json", "") - async with aiofiles.open(context_file, 'r') as f: - context = json.loads(await f.read()) - - if not caps.supports_tool_calls and context.get('tools') is not None: - print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) + assert isinstance(contexts, list) + for context in contexts: + assert isinstance(context, Context) + assert isinstance(context.bindings, dict) + if not caps.supports_tool_calls and context.bindings.get('tools') is not None: + print(f'Skipping {context.name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue 
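+            # Likewise skip contexts that need a system message (present explicitly, or implied because
+            # unsupported tool definitions get folded into the system prompt) when the template reports
+            # no system-role support: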
- needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools - if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): + needs_tools_in_system = len(context.bindings.get('tools', [])) > 0 and not caps.supports_tools + if not caps.supports_system_role and (any(m['role'] == 'system' for m in context.bindings['messages']) or needs_tools_in_system): continue - output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt') + output_file = join_cmake_path(output_folder, f'{base_name}-{context.name}.txt') - output = template.apply(context) + output = template.apply(context.bindings) async with aiofiles.open(output_file, 'w') as f: await f.write(output) - print(f"{template_file} {caps_file} {context_file} {output_file}") + print(f"{template_file} {caps_file} {context.file} {output_file}") async def async_hf_download(repo_id: str, filename: str) -> str: headers = build_hf_headers() @@ -412,8 +417,9 @@ async def async_hf_download(repo_id: str, filename: str) -> str: response.raise_for_status() return await response.text() -async def process_model(output_folder: str, model_id: str, context_files: list): +async def process_model(output_folder: str, model_id: str, contexts: list[Context]): try: + print(f"Processing model {model_id}...", file=sys.stderr) config_str = await async_hf_download(model_id, "tokenizer_config.json") try: @@ -424,14 +430,16 @@ async def process_model(output_folder: str, model_id: str, context_files: list): assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!' chat_template = config['chat_template'] if isinstance(chat_template, str): - await handle_chat_template(output_folder, model_id, None, chat_template, context_files) + await handle_chat_template(output_folder, model_id, None, chat_template, contexts) else: await asyncio.gather(*[ - handle_chat_template(output_folder, model_id, ct['name'], ct['template'], context_files) + handle_chat_template(output_folder, model_id, ct['name'], ct['template'], contexts) for ct in chat_template ]) except Exception as e: logger.error(f"Error processing model {model_id}: {e}") + # import traceback + # traceback.print_exc() await handle_chat_template(output_folder, model_id, None, str(e), []) async def async_copy_file(src: str, dst: str): @@ -445,11 +453,15 @@ async def main(): parser.add_argument("json_context_files_or_model_ids", nargs="+", help="List of context JSON files or HuggingFace model IDs") args = parser.parse_args() - context_files = [] + contexts: list[Context] = [] model_ids = [] for file in args.json_context_files_or_model_ids: if file.endswith('.json'): - context_files.append(file) + async with aiofiles.open(file, 'r') as f: + contexts.append(Context( + name=os.path.basename(file).replace(".json", ""), + file=file, + bindings=json.loads(await f.read()))) else: model_ids.append(file) @@ -457,15 +469,10 @@ async def main(): if not os.path.isdir(output_folder): os.makedirs(output_folder) - # Copy context files to the output folder asynchronously - await asyncio.gather(*[ - async_copy_file(context_file, os.path.join(output_folder, os.path.basename(context_file))) - for context_file in context_files - ]) - - # Process models concurrently + # for model_id in model_ids: + # await process_model(output_folder, model_id, contexts) await asyncio.gather(*[ - process_model(output_folder, model_id, context_files) + process_model(output_folder, model_id, contexts) for model_id in model_ids ]) diff --git 
a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 47b0900..40d4793 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -152,8 +152,8 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES}) separate_arguments(test_args UNIX_COMMAND "${test_case}") list(GET test_args -1 last_arg) string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}") - add_test(NAME ${test_name} COMMAND $ ${test_args}) - set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127) + add_test(NAME test-supported-template-${test_name} COMMAND $ ${test_args}) + set_tests_properties(test-supported-template-${test_name} PROPERTIES SKIP_RETURN_CODE 127) endforeach() if (MINJA_FUZZTEST_ENABLED) From a72057e5190de2c612d4598bb10b4bfd0f53011f Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sun, 9 Feb 2025 17:57:08 +0000 Subject: [PATCH 37/59] Test (~180) more models (#54) * dump ~180 more of most popular chat models * unskip all but 1 test on win32 --- tests/CMakeLists.txt | 240 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 214 insertions(+), 26 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 40d4793..842ae62 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -60,6 +60,9 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar endif() target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) +# https://huggingface.co/models?other=conversational +# https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/?types=fine-tuned%2Cchat + set(MODEL_IDS # List of model IDs to test the chat template of. # For each of them, the tokenizer_config.json file will be fetched, and the template @@ -70,63 +73,248 @@ set(MODEL_IDS # For Gated models, you'll need to run `huggingface-cli login` (and be granted access) to download their template. 
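    # (Sketch, exact naming may vary: each ID below ends up exercised through ctest targets named
    # test-supported-template-<golden-basename>, derived from the golden files by the add_test loop
    # in this CMakeLists.)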
abacusai/Fewshot-Metamath-OrcaVicuna-Mistral - bofenghuang/vigogne-2-70b-chat - CohereForAI/c4ai-command-r-plus # Gated - databricks/dbrx-instruct # Gated - google/gemma-2-2b-it # Gated - google/gemma-7b-it # Gated + allenai/Llama-3.1-Tulu-3-405B + allenai/Llama-3.1-Tulu-3-405B-SFT + allenai/Llama-3.1-Tulu-3-8B + arcee-ai/Virtuoso-Lite + arcee-ai/Virtuoso-Medium-v2 + arcee-ai/Virtuoso-Small-v2 + AtlaAI/Selene-1-Mini-Llama-3.1-8B + avemio/GRAG-NEMO-12B-ORPO-HESSIAN-AI BEE-spoke-data/tFINE-900m-instruct-orpo - MiniMaxAI/MiniMax-Text-01 + bespokelabs/Bespoke-Stratos-7B + bfuzzy1/acheron-m1a-llama + bofenghuang/vigogne-2-70b-chat + bytedance-research/UI-TARS-72B-DPO + bytedance-research/UI-TARS-7B-DPO + bytedance-research/UI-TARS-7B-SFT + carsenk/phi3.5_mini_exp_825_uncensored + CohereForAI/aya-expanse-8b + CohereForAI/c4ai-command-r-plus + CohereForAI/c4ai-command-r7b-12-2024 + cyberagent/DeepSeek-R1-Distill-Qwen-14B-Japanese + cyberagent/DeepSeek-R1-Distill-Qwen-32B-Japanese + databricks/dbrx-instruct + DavieLion/Llama-3.2-1B-SPIN-iter3 + deepseek-ai/deepseek-coder-33b-instruct + deepseek-ai/deepseek-coder-6.7b-instruct + deepseek-ai/deepseek-coder-7b-instruct-v1.5 + deepseek-ai/DeepSeek-Coder-V2-Instruct + deepseek-ai/DeepSeek-Coder-V2-Lite-Base + deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct + deepseek-ai/deepseek-llm-67b-chat + deepseek-ai/deepseek-llm-7b-chat + deepseek-ai/DeepSeek-R1-Distill-Llama-70B + deepseek-ai/DeepSeek-R1-Distill-Llama-8B + deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + deepseek-ai/DeepSeek-R1-Distill-Qwen-32B + deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + deepseek-ai/DeepSeek-V2-Lite + deepseek-ai/DeepSeek-V2.5 + deepseek-ai/DeepSeek-V3 + Delta-Vector/Rei-12B + dicta-il/dictalm2.0-instruct + ehristoforu/Falcon3-8B-Franken-Basestruct + EpistemeAI/Mistral-Nemo-Instruct-12B-Philosophy-Math + FlofloB/83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit + FlofloB/test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit + godlikehhd/alpaca_data_sampled_ifd_new_5200 + godlikehhd/alpaca_data_score_max_0.7_2600 + google/gemma-2-27b-it + google/gemma-2-2b-it + google/gemma-2-2b-jpn-it + google/gemma-7b-it + HelpingAI/HAI-SER + HuggingFaceTB/SmolLM2-1.7B-Instruct + HuggingFaceTB/SmolLM2-135M-Instruct + HuggingFaceTB/SmolLM2-360M-Instruct + huihui-ai/DeepSeek-R1-Distill-Llama-70B-abliterated + huihui-ai/DeepSeek-R1-Distill-Llama-8B-abliterated + huihui-ai/DeepSeek-R1-Distill-Qwen-14B-abliterated-v2 + huihui-ai/DeepSeek-R1-Distill-Qwen-32B-abliterated + huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated-v2 + huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated + ibm-granite/granite-3.1-8b-instruct + Ihor/Text2Graph-R1-Qwen2.5-0.5b indischepartij/MiniCPM-3B-OpenHermes-2.5-v2 + Infinigence/Megrez-3B-Instruct + inflatebot/MN-12B-Mag-Mell-R1 + INSAIT-Institute/BgGPT-Gemma-2-27B-IT-v1.0 + jinaai/ReaderLM-v2 + Josephgflowers/TinyLlama_v1.1_math_code-world-test-1 + kms7530/chemeng_qwen-math-7b_24_1_100_1_nonmath + knifeayumu/Cydonia-v1.3-Magnum-v4-22B + langgptai/qwen1.5-7b-chat-sa-v0.1 + LatitudeGames/Wayfarer-12B + LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct + LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct + lightblue/DeepSeek-R1-Distill-Qwen-7B-Japanese + Magpie-Align/Llama-3-8B-Magpie-Align-v0.1 + Magpie-Align/Llama-3.1-8B-Magpie-Align-v0.1 mattshumer/Reflection-Llama-3.1-70B + MaziyarPanahi/calme-3.2-instruct-78b meetkai/functionary-medium-v3.1 meetkai/functionary-medium-v3.2 - meta-llama/Llama-3.1-8B-Instruct # Gated - 
meta-llama/Llama-3.2-3B-Instruct # Gated - meta-llama/Llama-3.3-70B-Instruct # Gated - meta-llama/Meta-Llama-3.1-8B-Instruct # Gated + meta-llama/Llama-2-7b-chat-hf + meta-llama/Llama-3.1-8B-Instruct + meta-llama/Llama-3.2-1B-Instruct + meta-llama/Llama-3.2-3B-Instruct + meta-llama/Llama-3.3-70B-Instruct + meta-llama/Meta-Llama-3-8B-Instruct + meta-llama/Meta-Llama-3.1-8B-Instruct microsoft/Phi-3-medium-4k-instruct microsoft/Phi-3-mini-4k-instruct microsoft/Phi-3-small-8k-instruct microsoft/Phi-3.5-mini-instruct microsoft/Phi-3.5-vision-instruct - mistralai/Mistral-7B-Instruct-v0.2 # Gated - mistralai/Mistral-Large-Instruct-2407 # Gated - mistralai/Mistral-Large-Instruct-2411 # Gated - mistralai/Mistral-Nemo-Instruct-2407 # Gated - mistralai/Mixtral-8x7B-Instruct-v0.1 # Gated + microsoft/phi-4 + migtissera/Tess-3-Mistral-Nemo-12B + MiniMaxAI/MiniMax-Text-01 + MiniMaxAI/MiniMax-VL-01 + ministral/Ministral-3b-instruct + mistralai/Codestral-22B-v0.1 + mistralai/Mistral-7B-Instruct-v0.1 + mistralai/Mistral-7B-Instruct-v0.2 + mistralai/Mistral-7B-Instruct-v0.3 + mistralai/Mistral-Large-Instruct-2407 + mistralai/Mistral-Large-Instruct-2411 + mistralai/Mistral-Nemo-Instruct-2407 + mistralai/Mistral-Small-24B-Instruct-2501 + mistralai/Mixtral-8x7B-Instruct-v0.1 + mkurman/Qwen2.5-14B-DeepSeek-R1-1M mlabonne/AlphaMonarch-7B + mlx-community/Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32 + mlx-community/Qwen2.5-VL-7B-Instruct-8bit + mobiuslabsgmbh/DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1 + NaniDAO/deepseek-r1-qwen-2.5-32B-ablated + netcat420/MFANNv0.20 + netcat420/MFANNv0.24 + netease-youdao/Confucius-o1-14B NexaAIDev/Octopus-v2 NousResearch/Hermes-2-Pro-Llama-3-8B NousResearch/Hermes-2-Pro-Mistral-7B NousResearch/Hermes-3-Llama-3.1-70B + NovaSky-AI/Sky-T1-32B-Flash + NovaSky-AI/Sky-T1-32B-Preview + nvidia/AceMath-7B-RM + nvidia/Eagle2-1B + nvidia/Eagle2-9B nvidia/Llama-3.1-Nemotron-70B-Instruct-HF + OnlyCheeini/greesychat-turbo + onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX + open-thoughts/OpenThinker-7B openchat/openchat-3.5-0106 + Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2 OrionStarAI/Orion-14B-Chat + pankajmathur/orca_mini_v6_8b + PowerInfer/SmallThinker-3B-Preview + PrimeIntellect/INTELLECT-1-Instruct + princeton-nlp/Mistral-7B-Base-SFT-RDPO + princeton-nlp/Mistral-7B-Instruct-DPO + princeton-nlp/Mistral-7B-Instruct-RDPO + prithivMLmods/Bellatrix-Tiny-1.5B-R1 + prithivMLmods/Bellatrix-Tiny-1B-R1 + prithivMLmods/Bellatrix-Tiny-1B-v3 + prithivMLmods/Bellatrix-Tiny-3B-R1 + prithivMLmods/Blaze-14B-xElite + prithivMLmods/Calcium-Opus-14B-Elite2-R1 + prithivMLmods/Calme-Ties-78B + prithivMLmods/Calme-Ties2-78B + prithivMLmods/Calme-Ties3-78B + prithivMLmods/ChemQwen2-vL + prithivMLmods/GWQ2b + prithivMLmods/LatexMind-2B-Codec + prithivMLmods/Llama-3.2-6B-AlgoCode + prithivMLmods/Megatron-Opus-14B-Exp + prithivMLmods/Megatron-Opus-14B-Stock + prithivMLmods/Megatron-Opus-7B-Exp + prithivMLmods/Omni-Reasoner-Merged + prithivMLmods/Omni-Reasoner4-Merged + prithivMLmods/Primal-Opus-14B-Optimus-v1 + prithivMLmods/Qwen-7B-Distill-Reasoner + prithivMLmods/Qwen2.5-1.5B-DeepSeek-R1-Instruct + prithivMLmods/Qwen2.5-14B-DeepSeek-R1-1M + prithivMLmods/Qwen2.5-32B-DeepSeek-R1-Instruct + prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M + prithivMLmods/QwQ-Math-IO-500M + prithivMLmods/Triangulum-v2-10B + qingy2024/Falcon3-2x10B-MoE-Instruct + Qwen/QVQ-72B-Preview + Qwen/Qwen1.5-7B-Chat Qwen/Qwen2-7B-Instruct + Qwen/Qwen2-VL-72B-Instruct Qwen/Qwen2-VL-7B-Instruct + Qwen/Qwen2.5-0.5B + Qwen/Qwen2.5-1.5B-Instruct + 
Qwen/Qwen2.5-14B + Qwen/Qwen2.5-14B-Instruct-1M + Qwen/Qwen2.5-32B + Qwen/Qwen2.5-32B-Instruct + Qwen/Qwen2.5-3B-Instruct + Qwen/Qwen2.5-72B-Instruct + Qwen/Qwen2.5-7B Qwen/Qwen2.5-7B-Instruct + Qwen/Qwen2.5-7B-Instruct-1M + Qwen/Qwen2.5-Coder-32B-Instruct + Qwen/Qwen2.5-Coder-7B-Instruct + Qwen/Qwen2.5-Math-1.5B Qwen/Qwen2.5-Math-7B-Instruct + Qwen/Qwen2.5-VL-3B-Instruct + Qwen/Qwen2.5-VL-72B-Instruct + Qwen/Qwen2.5-VL-7B-Instruct Qwen/QwQ-32B-Preview + rubenroy/Zurich-14B-GCv2-5m + rubenroy/Zurich-7B-GCv2-5m + RWKV-Red-Team/ARWKV-7B-Preview-0.1 + SakanaAI/TinySwallow-1.5B + SakanaAI/TinySwallow-1.5B-Instruct + Sao10K/70B-L3.3-Cirrus-x1 + SentientAGI/Dobby-Mini-Leashed-Llama-3.1-8B + SentientAGI/Dobby-Mini-Unhinged-Llama-3.1-8B + silma-ai/SILMA-Kashif-2B-Instruct-v1.0 + simplescaling/s1-32B + sometimesanotion/Lamarck-14B-v0.7 + sonthenguyen/zephyr-sft-bnb-4bit-DPO-mtbr-180steps + Steelskull/L3.3-Damascus-R1 + Steelskull/L3.3-MS-Nevoria-70b + Steelskull/L3.3-Nevoria-R1-70b + sthenno/tempesthenno-icy-0130 + sumink/qwft + Tarek07/Progenitor-V1.1-LLaMa-70B teknium/OpenHermes-2.5-Mistral-7B TheBloke/FusionNet_34Bx2_MoE-AWQ + thirdeyeai/elevate360m + THUDM/glm-4-9b-chat + THUDM/glm-edge-1.5b-chat + tiiuae/Falcon3-10B-Instruct + TinyLlama/TinyLlama-1.1B-Chat-v1.0 + UCLA-AGI/Mistral7B-PairRM-SPPO-Iter3 + unsloth/DeepSeek-R1-Distill-Llama-8B + unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit + unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit + upstage/solar-pro-preview-instruct + ValiantLabs/Llama3.1-8B-Enigma + xwen-team/Xwen-72B-Chat + xwen-team/Xwen-7B-Chat # Broken, TODO: - # meetkai/functionary-medium-v3.1 # jinja2 expectation is computed w/ wrong escapes - # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7 # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 + # Almawave/Velvet-14B + # deepseek-ai/DeepSeek-R1 + # deepseek-ai/DeepSeek-R1-Zero + # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7 + # HuggingFaceTB/SmolVLM-256M-Instruct + # HuggingFaceTB/SmolVLM-500M-Instruct + # HuggingFaceTB/SmolVLM-Instruct + # meta-llama/Llama-3.2-11B-Vision-Instruct + # unsloth/DeepSeek-R1 ) -if(NOT WIN32) - list(APPEND MODEL_IDS +if(WIN32) + list(REMOVE_ITEM MODEL_IDS # Needs investigation (https://github.com/google/minja/issues/40) - CohereForAI/c4ai-command-r7b-12-2024 # Gated - deepseek-ai/deepseek-coder-33b-instruct - deepseek-ai/DeepSeek-Coder-V2-Instruct - deepseek-ai/DeepSeek-V2.5 - deepseek-ai/DeepSeek-R1-Distill-Llama-8B - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B - deepseek-ai/DeepSeek-R1-Distill-Qwen-32B + CohereForAI/c4ai-command-r7b-12-2024 ) endif() From dee1b8921ccdc51846080fda5299bae2b592d354 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Tue, 18 Feb 2025 17:27:27 -0500 Subject: [PATCH 38/59] Use the canonical #include location for json.hpp (#55) use the canonical #include location for json.hpp --- CMakeLists.txt | 1 - README.md | 2 +- include/minja/chat-template.hpp | 4 +++- include/minja/minja.hpp | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bd11142..62d2fc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,7 +91,6 @@ message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") include_directories( include/minja ${json_SOURCE_DIR}/include - ${json_SOURCE_DIR}/include/nlohmann ) add_subdirectory(examples) diff --git a/README.md b/README.md index d45c45e..eb17ef9 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 
@@ It is **not general purpose**: it includes just what’s needed for actual chat ## Usage: -This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++17 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json)'s `json.hpp` in your include path. +This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++17 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json) in your include path. See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.hpp](./include/minja/chat-template.hpp) (experimental). diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 882ba41..c74f378 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -9,10 +9,12 @@ #pragma once #include "minja.hpp" -#include + #include #include +#include + using json = nlohmann::ordered_json; namespace minja { diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index c58dd66..e6a5e25 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -16,7 +16,8 @@ #include #include #include -#include + +#include using json = nlohmann::ordered_json; From 8a76f7815e8a3ae00bd233c2b5a8b7d4e86564ec Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 6 Mar 2025 21:45:48 +0000 Subject: [PATCH 39/59] Support Qwen/QwQ-32B (needed str.{split,lstrip}) (#56) * Add str.split(sep), .strip([chars]), .rstrip([chars]), .lstrip([chars]) * test Qwen/QwQ-32B template --- include/minja/minja.hpp | 42 ++++++++++++++++++++++++++++++++++++----- tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 7 +++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index e6a5e25..12ca8af 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1379,13 +1379,27 @@ struct ArgumentsExpression { } }; -static std::string strip(const std::string & s) { - auto start = s.find_first_not_of(" \t\n\r"); +static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { + auto charset = chars.empty() ? " \t\n\r" : chars; + auto start = left ? s.find_first_not_of(charset) : 0; if (start == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\n\r"); + auto end = right ? s.find_last_not_of(charset) : s.size() - 1; return s.substr(start, end - start + 1); } +static std::vector split(const std::string & s, const std::string & sep) { + std::vector result; + size_t start = 0; + size_t end = s.find(sep); + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + sep.length(); + end = s.find(sep, start); + } + result.push_back(s.substr(start)); + return result; +} + static std::string capitalize(const std::string & s) { if (s.empty()) return s; auto result = s; @@ -1468,8 +1482,26 @@ class MethodCallExpr : public Expression { } else if (obj.is_string()) { auto str = obj.get(); if (method->get_name() == "strip") { - vargs.expectArgs("strip method", {0, 0}, {0, 0}); - return Value(strip(str)); + vargs.expectArgs("strip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars)); + } else if (method->get_name() == "lstrip") { + vargs.expectArgs("lstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? 
"" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ true, /* right= */ false)); + } else if (method->get_name() == "rstrip") { + vargs.expectArgs("rstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ false, /* right= */ true)); + } else if (method->get_name() == "split") { + vargs.expectArgs("split method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = split(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str)); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 842ae62..8475fd4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -263,6 +263,7 @@ set(MODEL_IDS Qwen/Qwen2.5-VL-3B-Instruct Qwen/Qwen2.5-VL-72B-Instruct Qwen/Qwen2.5-VL-7B-Instruct + Qwen/QwQ-32B Qwen/QwQ-32B-Preview rubenroy/Zurich-14B-GCv2-5m rubenroy/Zurich-7B-GCv2-5m diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index b4f4673..fdc7ca5 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -74,6 +74,13 @@ TEST(SyntaxTest, SimpleCases) { return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); }; + EXPECT_EQ("a", render("{{ ' a '.strip() }}", {}, {})); + EXPECT_EQ("a ", render("{{ ' a '.lstrip() }}", {}, {})); + EXPECT_EQ(" a", render("{{ ' a '.rstrip() }}", {}, {})); + EXPECT_EQ("bcXYZab", render("{{ 'abcXYZabc'.strip('ac') }}", {}, {})); + + EXPECT_EQ(R"(["a", "b"])", render("{{ 'a b'.split(' ') | tojson }}", {}, {})); + EXPECT_EQ( "Ok", render("{{ 'ok'.capitalize() }}", {}, {})); From 2b4a2c75a67107102764c5805f3c50c6da87efe2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 31 Mar 2025 01:15:31 +0100 Subject: [PATCH 40/59] Fix some clang-format lints (#57) * address some clang-format lints * Fix tool response polyfill regression * Fix more lints * rm -exclude-header-filter * disable clang-analyzer-cplusplus.StringChecker --- CMakeLists.txt | 16 +++ README.md | 2 + include/minja/chat-template.hpp | 18 ++- include/minja/minja.hpp | 191 +++++++++++++++++------------- tests/test-supported-template.cpp | 18 +-- 5 files changed, 145 insertions(+), 100 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 62d2fc6..739c635 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,22 @@ project(minja VERSION 1.0.0 LANGUAGES CXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +# Test if clang-tidy is available +find_program(CLANG_TIDY_EXE NAMES "clang-tidy") +if (CLANG_TIDY_EXE) + message(STATUS "clang-tidy found: ${CLANG_TIDY_EXE}") + set(CMAKE_CXX_CLANG_TIDY + clang-tidy; + -header-filter=include/minja/.*; + # https://clang.llvm.org/extra/clang-tidy/checks/list.html + # TODO: enable more / disable less checks: google-*,misc-*,modernize-*,performance-* + -checks=-*,clang-analyzer-*,clang-diagnostic-*,cppcoreguideline-*,bugprone-*,-bugprone-suspicious-include,-bugprone-assignment-in-if-condition,-bugprone-narrowing-conversions,-bugprone-easily-swappable-parameters,-bugprone-inc-dec-in-conditions,-bugprone-exception-escape,-clang-analyzer-cplusplus.StringChecker; + -warnings-as-errors=*; + ) +else() + message(STATUS "clang-tidy not found") +endif() + if (MSVC) set(MINJA_FUZZTEST_ENABLED_DEFAULT OFF) set(MINJA_USE_VENV_DEFAULT OFF) diff --git a/README.md b/README.md index eb17ef9..6f0ddf8 100644 --- 
a/README.md +++ b/README.md @@ -166,6 +166,8 @@ Main limitations (non-exhaustive list): ctest --test-dir build -j --output-on-failure ``` +- Bonus: install `clang-tidy` before building (on MacOS: `brew install llvm ; sudo ln -s "$(brew --prefix llvm)/bin/clang-tidy" "/usr/local/bin/clang-tidy"`) + - Fuzzing tests - Note: `fuzztest` **[doesn't work](https://github.com/google/fuzztest/issues/179)** natively on Windows or MacOS. diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index c74f378..60d998d 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -10,6 +10,13 @@ #include "minja.hpp" +#include +#include +#include +#include +#include +#include +#include #include #include @@ -427,7 +434,7 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && content != "") { + if (!content.is_null() && !content.empty()) { obj["content"] = content; } message["content"] = obj.dump(2); @@ -437,13 +444,12 @@ class chat_template { if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { - {"tool_response", { - {"content", message.at("content")}, - }}, + {"tool_response", json::object()}, }; if (message.contains("name")) { - obj["tool_response"]["name"] = message.at("name"); + obj["tool_response"]["tool"] = message.at("name"); } + obj["tool_response"]["content"] = message.at("content"); if (message.contains("tool_call_id")) { obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); } @@ -512,7 +518,7 @@ class chat_template { static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { json messages_with_system = messages; - if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 12ca8af..323b0fd 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -8,14 +8,25 @@ // SPDX-License-Identifier: MIT #pragma once +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include +#include +#include #include -#include +#include #include +#include +#include +#include #include +#include +#include #include @@ -732,51 +743,51 @@ class TemplateToken { struct TextTemplateToken : public TemplateToken { std::string text; - TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} + TextTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, loc, pre, post), text(t) {} }; struct ExpressionTemplateToken : public TemplateToken { std::shared_ptr expr; - ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} + ExpressionTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, loc, pre, post), expr(std::move(e)) {} }; struct IfTemplateToken : public TemplateToken { std::shared_ptr condition; - IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, 
std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} + IfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, loc, pre, post), condition(std::move(c)) {} }; struct ElifTemplateToken : public TemplateToken { std::shared_ptr condition; - ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} + ElifTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, loc, pre, post), condition(std::move(c)) {} }; struct ElseTemplateToken : public TemplateToken { - ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} + ElseTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, loc, pre, post) {} }; struct EndIfTemplateToken : public TemplateToken { - EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} + EndIfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, loc, pre, post) {} }; struct MacroTemplateToken : public TemplateToken { std::shared_ptr name; Expression::Parameters params; - MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) - : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} + MacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, loc, pre, post), name(std::move(n)), params(std::move(p)) {} }; struct EndMacroTemplateToken : public TemplateToken { - EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} + EndMacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, loc, pre, post) {} }; struct FilterTemplateToken : public TemplateToken { std::shared_ptr filter; - FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) - : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} + FilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, loc, pre, post), filter(std::move(filter)) {} }; struct EndFilterTemplateToken : public TemplateToken { - EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} + EndFilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, loc, pre, post) {} }; struct ForTemplateToken : public TemplateToken { @@ -784,38 +795,38 @@ struct ForTemplateToken : public TemplateToken { std::shared_ptr iterable; std::shared_ptr condition; bool recursive; - ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + ForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, std::shared_ptr && c, bool r) - 
: TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} + : TemplateToken(Type::For, loc, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} }; struct EndForTemplateToken : public TemplateToken { - EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} + EndForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, loc, pre, post) {} }; struct GenerationTemplateToken : public TemplateToken { - GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} + GenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, loc, pre, post) {} }; struct EndGenerationTemplateToken : public TemplateToken { - EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} + EndGenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, loc, pre, post) {} }; struct SetTemplateToken : public TemplateToken { std::string ns; std::vector var_names; std::shared_ptr value; - SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} + SetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, loc, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} }; struct EndSetTemplateToken : public TemplateToken { - EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} + EndSetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, loc, pre, post) {} }; struct CommentTemplateToken : public TemplateToken { std::string text; - CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} + CommentTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, loc, pre, post), text(t) {} }; enum class LoopControlType { Break, Continue }; @@ -831,7 +842,7 @@ class LoopControlException : public std::runtime_error { struct LoopControlTemplateToken : public TemplateToken { LoopControlType control_type; - LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} + LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} }; class TemplateNode { @@ -869,8 +880,8 @@ class TemplateNode { class SequenceNode : public TemplateNode { std::vector> children; public: - SequenceNode(const Location & location, std::vector> && c) - : TemplateNode(location), children(std::move(c)) {} + SequenceNode(const 
Location & loc, std::vector> && c) + : TemplateNode(loc), children(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const auto& child : children) child->render(out, context); } @@ -879,7 +890,7 @@ class SequenceNode : public TemplateNode { class TextNode : public TemplateNode { std::string text; public: - TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + TextNode(const Location & loc, const std::string& t) : TemplateNode(loc), text(t) {} void do_render(std::ostringstream & out, const std::shared_ptr &) const override { out << text; } @@ -888,7 +899,7 @@ class TextNode : public TemplateNode { class ExpressionNode : public TemplateNode { std::shared_ptr expr; public: - ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + ExpressionNode(const Location & loc, std::shared_ptr && e) : TemplateNode(loc), expr(std::move(e)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); auto result = expr->evaluate(context); @@ -905,8 +916,8 @@ class ExpressionNode : public TemplateNode { class IfNode : public TemplateNode { std::vector, std::shared_ptr>> cascade; public: - IfNode(const Location & location, std::vector, std::shared_ptr>> && c) - : TemplateNode(location), cascade(std::move(c)) {} + IfNode(const Location & loc, std::vector, std::shared_ptr>> && c) + : TemplateNode(loc), cascade(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const auto& branch : cascade) { auto enter_branch = true; @@ -925,7 +936,7 @@ class IfNode : public TemplateNode { class LoopControlNode : public TemplateNode { LoopControlType control_type_; public: - LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} + LoopControlNode(const Location & loc, LoopControlType control_type) : TemplateNode(loc), control_type_(control_type) {} void do_render(std::ostringstream &, const std::shared_ptr &) const override { throw LoopControlException(control_type_); } @@ -939,9 +950,9 @@ class ForNode : public TemplateNode { bool recursive; std::shared_ptr else_body; public: - ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + ForNode(const Location & loc, std::vector && var_names, std::shared_ptr && iterable, std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) - : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + : TemplateNode(loc), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { // https://jinja.palletsprojects.com/en/3.0.x/templates/#for @@ -1026,8 +1037,8 @@ class MacroNode : public TemplateNode { std::shared_ptr body; std::unordered_map named_param_positions; public: - MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) - : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + MacroNode(const Location 
& loc, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(loc), name(std::move(n)), params(std::move(p)), body(std::move(b)) { for (size_t i = 0; i < params.size(); ++i) { const auto & name = params[i].first; if (!name.empty()) { @@ -1073,8 +1084,8 @@ class FilterNode : public TemplateNode { std::shared_ptr body; public: - FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) - : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + FilterNode(const Location & loc, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(loc), filter(std::move(f)), body(std::move(b)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!filter) throw std::runtime_error("FilterNode.filter is null"); @@ -1096,8 +1107,8 @@ class SetNode : public TemplateNode { std::vector var_names; std::shared_ptr value; public: - SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + SetNode(const Location & loc, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(loc), ns(ns), var_names(vns), value(std::move(v)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!value) throw std::runtime_error("SetNode.value is null"); if (!ns.empty()) { @@ -1119,8 +1130,8 @@ class SetTemplateNode : public TemplateNode { std::string name; std::shared_ptr template_value; public: - SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) - : TemplateNode(location), name(name), template_value(std::move(tv)) {} + SetTemplateNode(const Location & loc, const std::string & name, std::shared_ptr && tv) + : TemplateNode(loc), name(name), template_value(std::move(tv)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); Value value { template_value->render(context) }; @@ -1133,8 +1144,8 @@ class IfExpr : public Expression { std::shared_ptr then_expr; std::shared_ptr else_expr; public: - IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) - : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + IfExpr(const Location & loc, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(loc), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!condition) throw std::runtime_error("IfExpr.condition is null"); if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); @@ -1151,16 +1162,16 @@ class IfExpr : public Expression { class LiteralExpr : public Expression { Value value; public: - LiteralExpr(const Location & location, const Value& v) - : Expression(location), value(v) {} + LiteralExpr(const Location & loc, const Value& v) + : Expression(loc), value(v) {} Value do_evaluate(const std::shared_ptr &) const override { return value; } }; class ArrayExpr : public Expression { std::vector> elements; public: - ArrayExpr(const Location & location, std::vector> && e) - : Expression(location), elements(std::move(e)) {} + ArrayExpr(const Location & loc, std::vector> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const 
std::shared_ptr & context) const override { auto result = Value::array(); for (const auto& e : elements) { @@ -1174,8 +1185,8 @@ class ArrayExpr : public Expression { class DictExpr : public Expression { std::vector, std::shared_ptr>> elements; public: - DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) - : Expression(location), elements(std::move(e)) {} + DictExpr(const Location & loc, std::vector, std::shared_ptr>> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { auto result = Value::object(); for (const auto& [key, value] : elements) { @@ -1190,8 +1201,8 @@ class DictExpr : public Expression { class SliceExpr : public Expression { public: std::shared_ptr start, end; - SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) - : Expression(location), start(std::move(s)), end(std::move(e)) {} + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e) + : Expression(loc), start(std::move(s)), end(std::move(e)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1201,8 +1212,8 @@ class SubscriptExpr : public Expression { std::shared_ptr base; std::shared_ptr index; public: - SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) - : Expression(location), base(std::move(b)), index(std::move(i)) {} + SubscriptExpr(const Location & loc, std::shared_ptr && b, std::shared_ptr && i) + : Expression(loc), base(std::move(b)), index(std::move(i)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!base) throw std::runtime_error("SubscriptExpr.base is null"); if (!index) throw std::runtime_error("SubscriptExpr.index is null"); @@ -1244,8 +1255,8 @@ class UnaryOpExpr : public Expression { enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; std::shared_ptr expr; Op op; - UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) - : Expression(location), expr(std::move(e)), op(o) {} + UnaryOpExpr(const Location & loc, std::shared_ptr && e, Op o) + : Expression(loc), expr(std::move(e)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); auto e = expr->evaluate(context); @@ -1270,8 +1281,8 @@ class BinaryOpExpr : public Expression { std::shared_ptr right; Op op; public: - BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) - : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + BinaryOpExpr(const Location & loc, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(loc), left(std::move(l)), right(std::move(r)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); @@ -1428,8 +1439,8 @@ class MethodCallExpr : public Expression { std::shared_ptr method; ArgumentsExpression args; public: - MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + MethodCallExpr(const Location & loc, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} Value 
do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null"); @@ -1527,8 +1538,8 @@ class CallExpr : public Expression { public: std::shared_ptr object; ArgumentsExpression args; - CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), args(std::move(a)) {} + CallExpr(const Location & loc, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), args(std::move(a)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("CallExpr.object is null"); auto obj = object->evaluate(context); @@ -1543,8 +1554,8 @@ class CallExpr : public Expression { class FilterExpr : public Expression { std::vector> parts; public: - FilterExpr(const Location & location, std::vector> && p) - : Expression(location), parts(std::move(p)) {} + FilterExpr(const Location & loc, std::vector> && p) + : Expression(loc), parts(std::move(p)) {} Value do_evaluate(const std::shared_ptr & context) const override { Value result; bool first = true; @@ -2461,7 +2472,7 @@ class Parser { static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { - if (text.length() > 0 && text[0] == '\n') { + if (!text.empty() && text[0] == '\n') { text.erase(0, 1); } } @@ -2539,7 +2550,7 @@ class Parser { TemplateTokenIterator begin = tokens.begin(); auto it = begin; TemplateTokenIterator end = tokens.end(); - return parser.parseTemplate(begin, it, end, /* full= */ true); + return parser.parseTemplate(begin, it, end, /* fully= */ true); } }; @@ -2578,7 +2589,7 @@ inline std::shared_ptr Context::builtins() { throw std::runtime_error(args.at("message").get()); })); globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { - return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + return Value(args.at("value").dump(args.get("indent", -1), /* to_json= */ true)); })); globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { auto items = Value::array(); @@ -2600,7 +2611,7 @@ inline std::shared_ptr Context::builtins() { globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { auto items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not a list"); - if (items.size() == 0) return Value(); + if (items.empty()) return Value(); return items.at(items.size() - 1); })); globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { @@ -2744,12 +2755,17 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? 
"select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; - if (items.is_null()) + if (items.is_null()) { return Value::array(); - if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + } + if (!items.is_array()) { + throw std::runtime_error("object is not iterable: " + items.dump()); + } auto filter_fn = context->get(args.args[1]); - if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + if (filter_fn.is_null()) { + throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + } auto filter_args = Value::array(); for (size_t i = 2, n = args.args.size(); i < n; i++) { @@ -2871,20 +2887,25 @@ inline std::shared_ptr Context::builtins() { auto v = arg.get(); startEndStep[i] = v; param_set[i] = true; - } } - for (auto & [name, value] : args.kwargs) { - size_t i; - if (name == "start") i = 0; - else if (name == "end") i = 1; - else if (name == "step") i = 2; - else throw std::runtime_error("Unknown argument " + name + " for function range"); - - if (param_set[i]) { - throw std::runtime_error("Duplicate argument " + name + " for function range"); - } - startEndStep[i] = value.get(); - param_set[i] = true; + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") { + i = 0; + } else if (name == "end") { + i = 1; + } else if (name == "step") { + i = 2; + } else { + throw std::runtime_error("Unknown argument " + name + " for function range"); + } + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; } if (!param_set[1]) { throw std::runtime_error("Missing required argument 'end' for function range"); diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 965375f..ad3a711 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -84,10 +84,10 @@ static json caps_to_json(const minja::chat_template_caps &caps) { int main(int argc, char *argv[]) { if (argc != 5) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + std::cerr << "Usage: " << argv[0] << " " << "\n"; for (int i = 0; i < argc; i++) { - std::cerr << "argv[" << i << "] = " << argv[i] << std::endl; + std::cerr << "argv[" << i << "] = " << argv[i] << "\n"; } return 1; } @@ -102,7 +102,7 @@ int main(int argc, char *argv[]) { if (ctx_file == "n/a") { - std::cout << "# Skipping template: " << tmpl_file << "\n" << tmpl_str << std::endl; + std::cout << "# Skipping template: " << tmpl_file << "\n" << tmpl_str << "\n"; return 127; } @@ -121,8 +121,8 @@ int main(int argc, char *argv[]) { try { expected = minja::normalize_newlines(read_file(golden_file)); } catch (const std::exception &e) { - std::cerr << "Failed to read golden file: " << golden_file << std::endl; - std::cerr << e.what() << std::endl; + std::cerr << "Failed to read golden file: " << golden_file << "\n"; + std::cerr << e.what() << "\n"; return 1; } @@ -148,14 +148,14 @@ int main(int argc, char *argv[]) { try { actual = tmpl.apply(inputs); } catch (const std::exception &e) { - std::cerr << "Error applying template: " << e.what() << std::endl; + std::cerr << "Error applying template: " << e.what() << "\n"; return 1; } if (expected != actual) { if (getenv("WRITE_GOLDENS")) { write_file(golden_file, actual); - std::cerr << "Updated golden file: " << golden_file << std::endl; + std::cerr << "Updated golden file: " << golden_file << "\n"; } else { 
assert_equals(expected, actual); } @@ -169,10 +169,10 @@ int main(int argc, char *argv[]) { assert_equals(expected_caps, caps); #endif - std::cout << "Test passed successfully." << std::endl; + std::cout << "Test passed successfully." << "\n"; return 0; } catch (const std::exception &e) { - std::cerr << "Test failed: " << e.what() << std::endl; + std::cerr << "Test failed: " << e.what() << "\n"; return 1; } } From 84187ba510bea5d03a599b47d3d5650c7c43bf51 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 31 Mar 2025 01:59:21 +0100 Subject: [PATCH 41/59] Add `upper` filter (to support inclusionAI/Ling-Coder-lite) (#58) * Add upper filter * Test inclusionAI/Ling-Coder-lite template --- include/minja/minja.hpp | 20 ++++++++++++-------- tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 5 ++++- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 323b0fd..f1fd579 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2618,14 +2618,18 @@ inline std::shared_ptr Context::builtins() { auto & text = args.at("text"); return text.is_null() ? text : Value(strip(text.get())); })); - globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { - auto text = args.at("text"); - if (text.is_null()) return text; - std::string res; - auto str = text.get(); - std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); - return Value(res); - })); + auto char_transform_function = [](const std::string & name, const std::function & fn) { + return simple_function(name, { "text" }, [=](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), fn); + return Value(res); + }); + }; + globals.set("lower", char_transform_function("lower", ::tolower)); + globals.set("upper", char_transform_function("upper", ::toupper)); globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { args.expectArgs("default", {2, 3}, {0, 1}); auto & value = args.args[0]; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8475fd4..d7e6fa5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -137,6 +137,7 @@ set(MODEL_IDS huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated ibm-granite/granite-3.1-8b-instruct Ihor/Text2Graph-R1-Qwen2.5-0.5b + inclusionAI/Ling-Coder-lite indischepartij/MiniCPM-3B-OpenHermes-2.5-v2 Infinigence/Megrez-3B-Instruct inflatebot/MN-12B-Mag-Mell-R1 diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index fdc7ca5..5b894d3 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -88,7 +88,7 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "ok", render("{# Hey\nHo #}{#- Multiline...\nComments! 
-#}{{ 'ok' }}{# yo #}", {}, {})); - + EXPECT_EQ( " b", render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks)); @@ -130,6 +130,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "abc", render("{{ 'AbC' | lower }}", {}, {})); + EXPECT_EQ( + "ME", + render("{{ 'me' | upper }}", {}, {})); EXPECT_EQ( "the default1", render("{{ foo | default('the default') }}{{ 1 | default('nope') }}", {}, {})); From b810bce3f3497d253ff135dc9072333a943d3331 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 7 Apr 2025 16:08:09 -0700 Subject: [PATCH 42/59] Test strftime_now (+ nit typo fixes) (#59) * Test strftime_now in new test-chat-template * opportunistic typo fixes (Unashable -> Unhashable) * avoid gtest regex weirdness on win32 --- include/minja/minja.hpp | 10 +++--- tests/CMakeLists.txt | 19 ++++++++++ tests/test-chat-template.cpp | 69 ++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 5 deletions(-) create mode 100644 tests/test-chat-template.cpp diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index f1fd579..4b548da 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -233,7 +233,7 @@ class Value : public std::enable_shared_from_this { } } else if (is_object()) { if (!index.is_hashable()) - throw std::runtime_error("Unashable type: " + index.dump()); + throw std::runtime_error("Unhashable type: " + index.dump()); auto it = object_->find(index.primitive_); if (it == object_->end()) throw std::runtime_error("Key not found: " + index.dump()); @@ -252,7 +252,7 @@ class Value : public std::enable_shared_from_this { auto index = key.get(); return array_->at(index < 0 ? array_->size() + index : index); } else if (object_) { - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); auto it = object_->find(key.primitive_); if (it == object_->end()) return Value(); return it->second; @@ -261,7 +261,7 @@ class Value : public std::enable_shared_from_this { } void set(const Value& key, const Value& value) { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); (*object_)[key.primitive_] = value; } Value call(const std::shared_ptr & context, ArgumentsValue & args) const { @@ -398,7 +398,7 @@ class Value : public std::enable_shared_from_this { } return false; } else if (object_) { - if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump()); return object_->find(value.primitive_) != object_->end(); } else { throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); @@ -416,7 +416,7 @@ class Value : public std::enable_shared_from_this { return const_cast(this)->at(index); } Value& at(const Value & index) { - if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); if (is_array()) return array_->at(index.get()); if (is_object()) return object_->at(index.primitive_); throw std::runtime_error("Value is not an array or object: " + dump()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d7e6fa5..051c5c2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,6 +18,22 @@ 
target_link_libraries(test-syntax PRIVATE gmock ) +if (WIN32) + message(STATUS "Skipping test-chat-template on Win32") +else() + add_executable(test-chat-template test-chat-template.cpp) + target_compile_features(test-chat-template PUBLIC cxx_std_17) + if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-chat-template PUBLIC _CRT_SECURE_NO_WARNINGS) + target_compile_options(gtest PRIVATE -Wno-language-extension-token) + endif() + target_link_libraries(test-chat-template PRIVATE + nlohmann_json::nlohmann_json + gtest_main + gmock + ) +endif() + add_executable(test-polyfills test-polyfills.cpp) target_compile_features(test-polyfills PUBLIC cxx_std_17) if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") @@ -31,6 +47,9 @@ target_link_libraries(test-polyfills PRIVATE ) if (NOT CMAKE_CROSSCOMPILING) gtest_discover_tests(test-syntax) + if (NOT WIN32) + gtest_discover_tests(test-chat-template) + endif() add_test(NAME test-polyfills COMMAND test-polyfills) set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endif() diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp new file mode 100644 index 0000000..7e33191 --- /dev/null +++ b/tests/test-chat-template.cpp @@ -0,0 +1,69 @@ + +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#include "chat-template.hpp" +#include "gtest/gtest.h" +#include +#include + +#include +#include +#include + +using namespace minja; +using namespace testing; + +static std::string render_python(const std::string & template_str, const chat_template_inputs & inputs) { + json bindings = inputs.extra_context; + bindings["messages"] = inputs.messages; + bindings["tools"] = inputs.tools; + bindings["add_generation_prompt"] = inputs.add_generation_prompt; + json data { + {"template", template_str}, + {"bindings", bindings}, + {"options", { + {"trim_blocks", true}, + {"lstrip_blocks", true}, + {"keep_trailing_newline", false}, + }}, + }; + { + std::ofstream of("data.json"); + of << data.dump(2); + of.close(); + } + + auto pyExeEnv = getenv("PYTHON_EXECUTABLE"); + std::string pyExe = pyExeEnv ? 
pyExeEnv : "python3"; + + std::remove("out.txt"); + auto res = std::system((pyExe + " -m scripts.render data.json out.txt").c_str()); + if (res != 0) { + throw std::runtime_error("Failed to run python script with data: " + data.dump(2)); + } + + std::ifstream f("out.txt"); + std::string out((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + return out; +} + +static std::string render(const std::string & template_str, const chat_template_inputs & inputs, const chat_template_options & opts) { + if (getenv("USE_JINJA2")) { + return render_python(template_str, inputs); + } + chat_template tmpl( + template_str, + "", + ""); + return tmpl.apply(inputs, opts); +} + +TEST(ChatTemplateTest, SimpleCases) { + EXPECT_THAT(render("{{ strftime_now('%Y-%m-%d %H:%M:%S') }}", {}, {}), MatchesRegex(R"([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2})")); +} From ea5fe1bc5728d7dad62465e8cf729ab5ce2c92b3 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 7 Apr 2025 17:06:01 -0700 Subject: [PATCH 43/59] Update README.md (#60) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6f0ddf8..80aab6e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ _**This is not an official Google product**_ -Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016) and [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433)). +Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016), [Jan](https://jan.ai/) (through [cortex.cpp](https://github.com/menloresearch/cortex.cpp/pull/1814)) and [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433)). It is **not general purpose**: it includes just what’s needed for actual chat templates (very limited set of filters, tests and language features). Users with different needs should look at third-party alternatives such as [Jinja2Cpp](https://github.com/jinja2cpp/Jinja2Cpp), [Jinja2CppLight](https://github.com/hughperkins/Jinja2CppLight), or [inja](https://github.com/pantor/inja) (none of which we endorse). 
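For reference, here is a minimal sketch of how the chat-template API exercised by the new test-chat-template.cpp above can be driven from a standalone program. It is not part of any patch in this series: the template string and message content are invented for illustration, and the include path assumes the pre-packaging header layout the tests use at this point in the series.

```cpp
// Illustrative sketch only -- not part of the patch series.
// Drives minja::chat_template the way test-chat-template.cpp does; the template
// string and message content below are made up for the example.
#include "chat-template.hpp"
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    // A toy chat template; real templates ship in a model's tokenizer_config.json.
    minja::chat_template tmpl(
        "{% for message in messages %}<|{{ message.role }}|>{{ message.content }}\n{% endfor %}"
        "{% if add_generation_prompt %}<|assistant|>{% endif %}",
        /* bos_token= */ "",
        /* eos_token= */ "");

    minja::chat_template_inputs inputs;
    inputs.messages = json::array({
        {{"role", "user"}, {"content", "What time is it?"}},
    });
    inputs.add_generation_prompt = true;

    // apply() renders the template against the inputs, as in tmpl.apply(inputs, opts) above.
    std::cout << tmpl.apply(inputs) << std::endl;
    return 0;
}
```

With a real model template the same three steps apply: construct `chat_template` from the template source and special tokens, fill a `chat_template_inputs`, and call `apply()`.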
From fcb5a0d3380fd1147c15be332b185a81e7aeae65 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 15 Apr 2025 18:08:30 +0100 Subject: [PATCH 44/59] fix more clangd lints (#61) --- include/minja/chat-template.hpp | 2 ++ include/minja/minja.hpp | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 60d998d..142a58e 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -13,10 +13,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 4b548da..ec89bc7 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -676,8 +677,8 @@ class Expression { class VariableExpr : public Expression { std::string name; public: - VariableExpr(const Location & location, const std::string& n) - : Expression(location), name(n) {} + VariableExpr(const Location & loc, const std::string& n) + : Expression(loc), name(n) {} std::string get_name() const { return name; } Value do_evaluate(const std::shared_ptr & context) const override { if (!context->contains(name)) { From b152576afd6e5117f45923f75497005868877e21 Mon Sep 17 00:00:00 2001 From: jhlee525 Date: Fri, 25 Apr 2025 00:35:16 +0900 Subject: [PATCH 45/59] make minja packagable by setting target in cmake (#62) * make minja packagable by setting target in cmake * preventing gtest install * Add cmake instruct in README --- CMakeLists.txt | 94 ++++++++++++++++++++----------- README.md | 8 +++ examples/CMakeLists.txt | 2 +- examples/chat-template.cpp | 2 +- examples/raw.cpp | 2 +- tests/CMakeLists.txt | 10 ++-- tests/test-capabilities.cpp | 2 +- tests/test-chat-template.cpp | 2 +- tests/test-fuzz.cpp | 4 +- tests/test-polyfills.cpp | 4 +- tests/test-supported-template.cpp | 2 +- tests/test-syntax.cpp | 2 +- 12 files changed, 84 insertions(+), 50 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 739c635..6a3760c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,8 @@ cmake_policy(SET CMP0135 NEW) # https://cmake.org/cmake/help/latest/policy/CMP01 project(minja VERSION 1.0.0 LANGUAGES CXX) +add_library(minja INTERFACE) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Test if clang-tidy is available @@ -36,6 +38,8 @@ else() set(MINJA_FUZZTEST_ENABLED_DEFAULT ON) set(MINJA_USE_VENV_DEFAULT ON) endif() +option(MINJA_TEST_ENABLED "minja: Build with test(python interpreter required)" ON) +option(MINJA_EXAMPLE_ENABLED "minja: Build with example" ON) option(MINJA_FUZZTEST_ENABLED "minja: fuzztests enabled" MINJA_FUZZTEST_ENABLED_DEFAULT) option(MINJA_FUZZTEST_FUZZING_MODE "minja: run fuzztests (if enabled) in fuzzing mode" OFF) option(MINJA_USE_VENV "minja: use Python venv for build" MINJA_USE_VENV_DEFAULT) @@ -53,15 +57,20 @@ include(FetchContent) # Fetch nlohmann/json FetchContent_Declare(json URL https://github.com/nlohmann/json/archive/refs/heads/develop.zip) FetchContent_MakeAvailable(json) - -if (MINJA_FUZZTEST_ENABLED) - # Fetch google/fuzztest (and indirectly, gtest) - FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/heads/main.zip) - FetchContent_MakeAvailable(fuzztest) -else() - # Fetch gtest - FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/heads/main.zip) - FetchContent_MakeAvailable(googletest) +target_link_libraries(minja INTERFACE 
nlohmann_json::nlohmann_json) + +if(MINJA_TEST_ENABLED) + if (MINJA_FUZZTEST_ENABLED) + # Fetch google/fuzztest (and indirectly, gtest) + FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/heads/main.zip) + FetchContent_MakeAvailable(fuzztest) + message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") + else() + # Fetch gtest + set(INSTALL_GTEST OFF) + FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/heads/main.zip) + FetchContent_MakeAvailable(googletest) + endif() endif() # Use ccache if installed @@ -77,25 +86,27 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() -set(Python_FIND_STRATEGY LOCATION CACHE STRING "Python find strategy" FORCE) -find_package(Python COMPONENTS Interpreter REQUIRED) -if(MINJA_USE_VENV) - # Create a python venv w/ the required dependencies - set(VENV_DIR "${CMAKE_BINARY_DIR}/venv") - if(WIN32) - set(VENV_PYTHON "${VENV_DIR}/Scripts/python.exe") - else() - set(VENV_PYTHON "${VENV_DIR}/bin/python") +if(MINJA_TEST_ENABLED) + set(Python_FIND_STRATEGY LOCATION CACHE STRING "Python find strategy" FORCE) + find_package(Python COMPONENTS Interpreter REQUIRED) + if(MINJA_USE_VENV) + # Create a python venv w/ the required dependencies + set(VENV_DIR "${CMAKE_BINARY_DIR}/venv") + if(WIN32) + set(VENV_PYTHON "${VENV_DIR}/Scripts/python.exe") + else() + set(VENV_PYTHON "${VENV_DIR}/bin/python") + endif() + execute_process( + COMMAND ${Python_EXECUTABLE} -m venv "${VENV_DIR}" + COMMAND_ERROR_IS_FATAL ANY) + execute_process( + COMMAND ${VENV_PYTHON} -m pip install -r "${CMAKE_SOURCE_DIR}/requirements.txt" + COMMAND_ERROR_IS_FATAL ANY) + set(Python_EXECUTABLE "${VENV_PYTHON}" CACHE FILEPATH "Path to Python executable in venv" FORCE) endif() - execute_process( - COMMAND ${Python_EXECUTABLE} -m venv "${VENV_DIR}" - COMMAND_ERROR_IS_FATAL ANY) - execute_process( - COMMAND ${VENV_PYTHON} -m pip install -r "${CMAKE_SOURCE_DIR}/requirements.txt" - COMMAND_ERROR_IS_FATAL ANY) - set(Python_EXECUTABLE "${VENV_PYTHON}" CACHE FILEPATH "Path to Python executable in venv" FORCE) + message(STATUS "Python executable: ${Python_EXECUTABLE}") endif() -message(STATUS "Python executable: ${Python_EXECUTABLE}") find_program(CPPCHECK cppcheck) if(CPPCHECK) @@ -103,14 +114,29 @@ if(CPPCHECK) message(STATUS "cppcheck found: ${CPPCHECK}") endif() -message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") -include_directories( - include/minja - ${json_SOURCE_DIR}/include +include(GNUInstallDirs) +target_include_directories(minja INTERFACE + $ + $ ) -add_subdirectory(examples) +install(FILES + ${PROJECT_SOURCE_DIR}/include/minja/minja.hpp + ${PROJECT_SOURCE_DIR}/include/minja/chat-template.hpp + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/minja +) +install( + TARGETS minja + EXPORT "${TARGETS_EXPORT_NAME}" + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/minja # for downstream projects +) -enable_testing() -include(GoogleTest) -add_subdirectory(tests) +if(MINJA_EXAMPLE_ENABLED) + add_subdirectory(examples) +endif() + +if(MINJA_TEST_ENABLED) + enable_testing() + include(GoogleTest) + add_subdirectory(tests) +endif() diff --git a/README.md b/README.md index 80aab6e..868bcf8 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,14 @@ It is **not general purpose**: it includes just what’s needed for actual chat This library is header-only: just copy the header(s) you need, make sure to use a compiler that 
handles C++17 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json) in your include path. +If your project is based on [cmake](https://cmake.org/), you can simply import it by using `FetchContent`. +``` +FetchContent_Declare(minja GIT_REPOSITORY "https://github.com/google/minja") +FetchContent_MakeAvailable(minja) + +target_link_libraries(<your_target> PRIVATE minja) +``` + See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.hpp](./include/minja/chat-template.hpp) (experimental). For raw Jinja templating (see [examples/raw.cpp](./examples/raw.cpp)): diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 7d75913..16e8fb2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,7 +11,7 @@ foreach(example ) add_executable(${example} ${example}.cpp) target_compile_features(${example} PUBLIC cxx_std_17) - target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json) + target_link_libraries(${example} PRIVATE minja) if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") target_compile_definitions(${example} PUBLIC _CRT_SECURE_NO_WARNINGS) endif() diff --git a/examples/chat-template.cpp b/examples/chat-template.cpp index d1838e7..792a157 100644 --- a/examples/chat-template.cpp +++ b/examples/chat-template.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include +#include #include using json = nlohmann::ordered_json; diff --git a/examples/raw.cpp b/examples/raw.cpp index d36a03a..c129449 100644 --- a/examples/raw.cpp +++ b/examples/raw.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include +#include #include using json = nlohmann::ordered_json; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 051c5c2..aa3a756 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,7 +13,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-syntax PRIVATE - nlohmann_json::nlohmann_json + minja gtest_main gmock ) @@ -28,7 +28,7 @@ else() target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-chat-template PRIVATE - nlohmann_json::nlohmann_json + minja gtest_main gmock ) @@ -41,7 +41,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-polyfills PRIVATE - nlohmann_json::nlohmann_json + minja gtest_main gmock ) @@ -61,7 +61,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-capabilities PRIVATE - nlohmann_json::nlohmann_json + minja gtest_main gmock ) @@ -77,7 +77,7 @@ target_compile_features(test-supported-template PUBLIC cxx_std_17) if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") target_compile_definitions(test-supported-template PUBLIC _CRT_SECURE_NO_WARNINGS) endif() -target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) +target_link_libraries(test-supported-template PRIVATE minja) # https://huggingface.co/models?other=conversational # https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/?types=fine-tuned%2Cchat diff --git a/tests/test-capabilities.cpp 
b/tests/test-capabilities.cpp index 225581f..53e3999 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "chat-template.hpp" +#include "minja/chat-template.hpp" #include #include diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 7e33191..0831275 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,7 +7,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "chat-template.hpp" +#include "minja/chat-template.hpp" #include "gtest/gtest.h" #include #include diff --git a/tests/test-fuzz.cpp b/tests/test-fuzz.cpp index c256079..7169bce 100644 --- a/tests/test-fuzz.cpp +++ b/tests/test-fuzz.cpp @@ -9,8 +9,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index d1c598b..5bc1226 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -6,14 +6,14 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "minja.hpp" +#include "minja/minja.hpp" #include #include #include #include #include -#include "chat-template.hpp" +#include "minja/chat-template.hpp" using namespace minja; diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index ad3a711..db23a4a 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "chat-template.hpp" +#include "minja/chat-template.hpp" #include #include diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 5b894d3..5bc82fa 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "minja.hpp" +#include "minja/minja.hpp" #include #include From 6a458a8fa6dc18257b4864e087919de6b754aad2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 7 May 2025 18:12:42 +0100 Subject: [PATCH 46/59] Update README.md (#68) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 868bcf8..aa3579c 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ It is **not general purpose**: it includes just what’s needed for actual chat > [!WARNING] > TL;DR: use of Minja is *at your own risk*, and the risks are plenty! See [Security & Privacy](#security--privacy) section below. 
+> [!IMPORTANT] +> [@ochafik](https://github.com/ochafik) has left Google, watch out for https://github.com/ochafik/minja + [![CI](https://github.com/google/minja/actions/workflows/build.yml/badge.svg)](https://github.com/google/minja/actions/workflows/build.yml) ## Design goals: From cb5f47ecc28be1bf98d07c30195d3670ff12a543 Mon Sep 17 00:00:00 2001 From: Park Woorak Date: Thu, 15 May 2025 19:18:58 +0900 Subject: [PATCH 47/59] Add `true` and `false` as operands of `is` and `is not` (#67) * feat: add `true` and `false` as operands of `is` and `is not` * test: add tests for `is true/false` and `is not true/false` --------- Co-authored-by: Olivier Chafik --- include/minja/minja.hpp | 2 ++ tests/test-syntax.cpp | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index ec89bc7..39821cd 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1306,6 +1306,8 @@ class BinaryOpExpr : public Expression { if (name == "iterable") return l.is_iterable(); if (name == "sequence") return l.is_array(); if (name == "defined") return !l.is_null(); + if (name == "true") return l.to_bool(); + if (name == "false") return !l.to_bool(); throw std::runtime_error("Unknown type for 'is' operator: " + name); }; auto value = eval(); diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 5bc82fa..a5a2707 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -217,6 +217,18 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "False", render(R"({% set foo = true %}{{ not foo is defined }})", {}, {})); + EXPECT_EQ( + "True", + render(R"({% set foo = true %}{{ foo is true }})", {}, {})); + EXPECT_EQ( + "False", + render(R"({% set foo = true %}{{ foo is false }})", {}, {})); + EXPECT_EQ( + "True", + render(R"({% set foo = false %}{{ foo is not true }})", {}, {})); + EXPECT_EQ( + "False", + render(R"({% set foo = false %}{{ foo is not false }})", {}, {})); EXPECT_EQ( R"({"a": "b"})", render(R"({{ {"a": "b"} | tojson }})", {}, {})); From 17a767af1dc61a2e8e750a7c8688e382b1359217 Mon Sep 17 00:00:00 2001 From: Taha Yassine <40228615+taha-yassine@users.noreply.github.com> Date: Thu, 15 May 2025 12:42:31 +0200 Subject: [PATCH 48/59] Support Qwen3 (str.startswith() and [::-1]) (#66) * Add str.startswith() * Add support for step=-1 in slice * Add Qwen3 template * Clamp out-of-bounds slice indices * Simplify subscript logic + handle any (non-zero) step --------- Co-authored-by: Olivier Chafik --- include/minja/minja.hpp | 90 +++++++++++++++++++++++++++-------------- tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 12 ++++++ 3 files changed, 73 insertions(+), 30 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 39821cd..ee123a7 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1201,9 +1201,9 @@ class DictExpr : public Expression { class SliceExpr : public Expression { public: - std::shared_ptr start, end; - SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e) - : Expression(loc), start(std::move(s)), end(std::move(e)) {} + std::shared_ptr start, end, step; + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e, std::shared_ptr && st = nullptr) + : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1220,18 +1220,35 @@ class SubscriptExpr : public Expression { if (!index) throw 
std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + auto len = target_value.size(); + auto wrap = [len](int64_t i) -> int64_t { + if (i < 0) { + return i + len; + } + return i; + }; + int64_t step = slice->step ? slice->step->evaluate(context).get() : 1; + if (!step) { + throw std::runtime_error("slice step cannot be zero"); + } + int64_t start = slice->start ? wrap(slice->start->evaluate(context).get()) : (step < 0 ? len - 1 : 0); + int64_t end = slice->end ? wrap(slice->end->evaluate(context).get()) : (step < 0 ? -1 : len); if (target_value.is_string()) { std::string s = target_value.get(); - if (start < 0) start = s.size() + start; - if (end < 0) end = s.size() + end; - return s.substr(start, end - start); - } else if (target_value.is_array()) { - if (start < 0) start = target_value.size() + start; - if (end < 0) end = target_value.size() + end; + + std::string result; + if (start < end && step == 1) { + result = s.substr(start, end - start); + } else { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { + result += s[i]; + } + } + return result; + + } else if (target_value.is_array()) { auto result = Value::array(); - for (auto i = start; i < end; ++i) { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { result.push_back(target_value.at(i)); } return result; @@ -1523,6 +1540,10 @@ class MethodCallExpr : public Expression { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "startswith") { + vargs.expectArgs("startswith method", {1, 1}, {0, 0}); + auto prefix = vargs.args[0].get(); + return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin()); } else if (method->get_name() == "title") { vargs.expectArgs("title method", {0, 0}, {0, 0}); auto res = str; @@ -2085,28 +2106,37 @@ class Parser { while (it != end && consumeSpaces() && peekSymbols({ "[", "." 
})) { if (!consumeToken("[").empty()) { - std::shared_ptr index; + std::shared_ptr index; + auto slice_loc = get_location(); + std::shared_ptr start, end, step; + bool has_first_colon = false, has_second_colon = false; + + if (!peekSymbols({ ":" })) { + start = parseExpression(); + } + + if (!consumeToken(":").empty()) { + has_first_colon = true; + if (!peekSymbols({ ":", "]" })) { + end = parseExpression(); + } if (!consumeToken(":").empty()) { - auto slice_end = parseExpression(); - index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); - } else { - auto slice_start = parseExpression(); - if (!consumeToken(":").empty()) { - consumeSpaces(); - if (peekSymbols({ "]" })) { - index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); - } else { - auto slice_end = parseExpression(); - index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); - } - } else { - index = std::move(slice_start); + has_second_colon = true; + if (!peekSymbols({ "]" })) { + step = parseExpression(); } } - if (!index) throw std::runtime_error("Empty index in subscript"); - if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + } + + if ((has_first_colon || has_second_colon) && (start || end || step)) { + index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); + } else { + index = std::move(start); + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); - value = std::make_shared(value->location, std::move(value), std::move(index)); + value = std::make_shared(value->location, std::move(value), std::move(index)); } else if (!consumeToken(".").empty()) { auto identifier = parseIdentifier(); if (!identifier) throw std::runtime_error("Expected identifier in subscript"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index aa3a756..09323b3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -318,6 +318,7 @@ set(MODEL_IDS ValiantLabs/Llama3.1-8B-Enigma xwen-team/Xwen-72B-Chat xwen-team/Xwen-7B-Chat + Qwen/Qwen3-4B # Broken, TODO: # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index a5a2707..a628aa2 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -184,6 +184,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "1", render(R"({{ 1 | safe }})", {}, {})); + EXPECT_EQ( + "True,False", + render(R"({{ 'abc'.startswith('ab') }},{{ ''.startswith('a') }})", {}, {})); EXPECT_EQ( "True,False", render(R"({{ 'abc'.endswith('bc') }},{{ ''.endswith('a') }})", {}, {})); @@ -477,6 +480,15 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "[1, 2, 3][0, 1][1, 2]", render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {})); + EXPECT_EQ( + "123;01;12", + render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }}", {}, {})); + EXPECT_EQ( + "[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1][0, 2][3, 1][2, 0]", + render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}{{ x[::2] }}{{ x[::-2] }}{{ x[-2::-2] }}", {}, {})); + EXPECT_EQ( + "3210;321;210;21;02;31;20", + render("{% set x = '0123' %}{{ x[::-1] }};{{ x[:0:-1] }};{{ x[2::-1] }};{{ x[2:0:-1] }};{{ x[::2] }};{{ x[::-2] }};{{ x[-2::-2] }}", {}, {})); EXPECT_EQ( "a", render("{{ ' a ' | trim }}", {}, {})); From f06140fa52fd140fe38e531ec373d8dc9c86aa06 Mon Sep 17 
00:00:00 2001 From: Park Woorak Date: Thu, 15 May 2025 19:43:52 +0900 Subject: [PATCH 49/59] Consider `"tool_calls"` instead of `"content"` in messages (#63) * fix: require one of 'content' or 'tool_calls' in messages * fix: check content is given during polyfill tool_calls --------- Co-authored-by: Olivier Chafik --- include/minja/chat-template.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 142a58e..ab5b521 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -395,8 +395,8 @@ class chat_template { for (const auto & message_ : adjusted_messages) { auto message = message_; - if (!message.contains("role") || !message.contains("content")) { - throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) { + throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump()); } std::string role = message.at("role"); @@ -417,7 +417,6 @@ class chat_template { } } if (polyfill_tool_calls) { - auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { if (tool_call.at("type") != "function") { @@ -436,8 +435,11 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && !content.empty()) { - obj["content"] = content; + if (message.contains("content")) { + auto content = message.at("content"); + if (!content.is_null() && !content.empty()) { + obj["content"] = content; + } } message["content"] = obj.dump(2); message.erase("tool_calls"); From da98a14ae1f11d193f1c2996088216b7def5ff70 Mon Sep 17 00:00:00 2001 From: Park Woorak Date: Thu, 10 Jul 2025 19:25:08 +0900 Subject: [PATCH 50/59] fix: use `""` instead of `nullptr` for templates requiring non null content (#72) fix: use empty string instead of nullptr for templates requiring non null --- include/minja/chat-template.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index ab5b521..3bc7c77 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -162,10 +162,14 @@ class chat_template { }), false); caps_.supports_tools = contains(out, "some_tool"); + auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); + auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); + auto make_tool_calls_msg = [&](const json & tool_calls) { return json { {"role", "assistant"}, - {"content", nullptr}, + {"content", caps_.requires_non_null_content? 
"" : nullptr}, {"tool_calls", tool_calls}, }; }; @@ -195,9 +199,6 @@ class chat_template { caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); - caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); if (caps_.supports_tool_calls) { auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); @@ -234,7 +235,7 @@ class chat_template { }; const json tool_call_msg { {"role", "assistant"}, - {"content", nullptr}, + {"content", caps_.requires_non_null_content ? "" : nullptr}, {"tool_calls", json::array({ { // TODO: detect if requires numerical id or fixed length == 6 like Nemo From 264eab4030de87d1b2e0ec4c99850524707f5e91 Mon Sep 17 00:00:00 2001 From: yichunkuo Date: Fri, 11 Jul 2025 09:00:24 +0000 Subject: [PATCH 51/59] Fix null content for chat-template. (#75) The nullptr cause `Test failed: basic_string: construction from null is not valid`. Change to use json j_null instead. Tests pass with `./scripts/tests.sh` --- include/minja/chat-template.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 3bc7c77..8f617db 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -165,11 +165,12 @@ class chat_template { auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); - + + json j_null; auto make_tool_calls_msg = [&](const json & tool_calls) { return json { {"role", "assistant"}, - {"content", caps_.requires_non_null_content? "" : nullptr}, + {"content", caps_.requires_non_null_content? "" : j_null}, {"tool_calls", tool_calls}, }; }; @@ -235,7 +236,7 @@ class chat_template { }; const json tool_call_msg { {"role", "assistant"}, - {"content", caps_.requires_non_null_content ? "" : nullptr}, + {"content", caps_.requires_non_null_content ? "" : j_null}, {"tool_calls", json::array({ { // TODO: detect if requires numerical id or fixed length == 6 like Nemo From 58568621432715b0ed38efd16238b0e7ff36c3ba Mon Sep 17 00:00:00 2001 From: yichunkuo Date: Fri, 11 Jul 2025 09:01:01 +0000 Subject: [PATCH 52/59] Enable SmolLM 3 template. (#74) * Enable SmolLM 3 template. - Add `replace` method for string type and update unit tests. - Support slice expression with omitted start or end index and update unit tests. - Fix the bug that `in` operator does not work for string type. - Update testing script to support fetching `chat-template.jinja` when the template is not in `tokenizer_config.json`. Test passed with SmolLM 3B model: https://huggingface.co/HuggingFaceTB/SmolLM3-3B * Update test cases of string replace. 
--------- Co-authored-by: Olivier Chafik --- include/minja/minja.hpp | 24 +++++++++++++++++++++--- scripts/fetch_templates_and_goldens.py | 10 +++++++++- tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 14 ++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index ee123a7..a36e446 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1355,8 +1355,13 @@ class BinaryOpExpr : public Expression { case Op::Gt: return l > r; case Op::Le: return l <= r; case Op::Ge: return l >= r; - case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); - case Op::NotIn: return !(r.is_array() && r.contains(l)); + case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) || + (l.is_string() && r.is_string() && + r.to_str().find(l.to_str()) != std::string::npos)); + case Op::NotIn: + return !(((r.is_array() || r.is_object()) && r.contains(l)) || + (l.is_string() && r.is_string() && + r.to_str().find(l.to_str()) != std::string::npos)); default: break; } throw std::runtime_error("Unknown binary operator"); @@ -1552,6 +1557,19 @@ class MethodCallExpr : public Expression { else res[i] = std::tolower(res[i]); } return res; + } else if (method->get_name() == "replace") { + vargs.expectArgs("replace method", {2, 3}, {0, 0}); + auto before = vargs.args[0].get(); + auto after = vargs.args[1].get(); + auto count = vargs.args.size() == 3 ? vargs.args[2].get() + : str.length(); + size_t start_pos = 0; + while ((start_pos = str.find(before, start_pos)) != std::string::npos && + count-- > 0) { + str.replace(start_pos, before.length(), after); + start_pos += after.length(); + } + return str; } } throw std::runtime_error("Unknown method: " + method->get_name()); @@ -2128,7 +2146,7 @@ class Parser { } } - if ((has_first_colon || has_second_colon) && (start || end || step)) { + if ((has_first_colon || has_second_colon)) { index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); } else { index = std::move(start); diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 6e65099..6950ffe 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -427,7 +427,15 @@ async def process_model(output_folder: str, model_id: str, contexts: list[Contex except json.JSONDecodeError: config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) - assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!' + if 'chat_template' not in config: + try: + chat_template = await async_hf_download(model_id, "chat_template.jinja") + config.update({'chat_template': chat_template}) + except Exception as e: + logger.error(f"Failed to fetch chat_template.jinja for model {model_id}: {e}") + raise e + + assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json or no chat_template.jinja file found!' 
chat_template = config['chat_template'] if isinstance(chat_template, str): await handle_chat_template(output_folder, model_id, None, chat_template, contexts) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09323b3..c624d5a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -148,6 +148,7 @@ set(MODEL_IDS HuggingFaceTB/SmolLM2-1.7B-Instruct HuggingFaceTB/SmolLM2-135M-Instruct HuggingFaceTB/SmolLM2-360M-Instruct + HuggingFaceTB/SmolLM3-3B huihui-ai/DeepSeek-R1-Distill-Llama-70B-abliterated huihui-ai/DeepSeek-R1-Distill-Llama-8B-abliterated huihui-ai/DeepSeek-R1-Distill-Qwen-14B-abliterated-v2 diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index a628aa2..1051b81 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -84,6 +84,12 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "Ok", render("{{ 'ok'.capitalize() }}", {}, {})); + EXPECT_EQ("aouiXYZaouiXYZaoui", + render("{{ 'abcXYZabcXYZabc'.replace('bc', 'oui') }}", {}, {})); + EXPECT_EQ("okXYZokXYZabc", + render("{{ 'abcXYZabcXYZabc'.replace('abc', 'ok', 2) }}", {}, {})); + EXPECT_EQ("abcXYZabcXYZabc", + render("{{ 'abcXYZabcXYZabc'.replace('def', 'ok') }}", {}, {})); EXPECT_EQ( "ok", @@ -199,6 +205,10 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "True,False", render(R"({{ 'a' in ["a"] }},{{ 'a' in [] }})", {}, {})); + EXPECT_EQ("True,False", + render(R"({{ 'a' in 'abc' }},{{ 'd' in 'abc' }})", {}, {})); + EXPECT_EQ("False,True", + render(R"({{ 'a' not in 'abc' }},{{ 'd' not in 'abc' }})", {}, {})); EXPECT_EQ( R"([{'a': 1}])", render(R"({{ [{"a": 1}, {"a": 2}, {}] | selectattr("a", "equalto", 1) | list }})", {}, {})); @@ -481,8 +491,8 @@ TEST(SyntaxTest, SimpleCases) { "[1, 2, 3][0, 1][1, 2]", render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {})); EXPECT_EQ( - "123;01;12", - render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }}", {}, {})); + "123;01;12;0123;0123", + render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }};{{ x[:] }};{{ x[::] }}", {}, {})); EXPECT_EQ( "[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1][0, 2][3, 1][2, 0]", render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}{{ x[::2] }}{{ x[::-2] }}{{ x[-2::-2] }}", {}, {})); From 6420c4c0af1da508bdc42488eeaf9a5e2ace4709 Mon Sep 17 00:00:00 2001 From: Yaroslav Tarkan Date: Fri, 8 Aug 2025 00:47:22 +0300 Subject: [PATCH 53/59] Support `llava-hf/llava-1.5-7b-hf` with `.upper()` string method (#76) * Support .upper() and .lower() string methods * Add syntax tests for upper and lower methods * Add llava-hf/llava-1.5-7b-hf to supported models --- include/minja/minja.hpp | 10 ++++++++++ tests/CMakeLists.txt | 1 + tests/test-syntax.cpp | 8 ++++++++ 3 files changed, 19 insertions(+) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index a36e446..9ae377d 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1541,6 +1541,16 @@ class MethodCallExpr : public Expression { } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str)); + } else if (method->get_name() == "upper") { + vargs.expectArgs("upper method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::toupper); + return Value(result); + } else if (method->get_name() == "lower") { + vargs.expectArgs("lower method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return Value(result); } else if 
(method->get_name() == "endswith") { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c624d5a..491eca0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -168,6 +168,7 @@ set(MODEL_IDS knifeayumu/Cydonia-v1.3-Magnum-v4-22B langgptai/qwen1.5-7b-chat-sa-v0.1 LatitudeGames/Wayfarer-12B + llava-hf/llava-1.5-7b-hf LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct lightblue/DeepSeek-R1-Distill-Qwen-7B-Japanese diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 1051b81..55e9b0b 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -91,6 +91,14 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ("abcXYZabcXYZabc", render("{{ 'abcXYZabcXYZabc'.replace('def', 'ok') }}", {}, {})); + EXPECT_EQ("HELLO WORLD", render("{{ 'hello world'.upper() }}", {}, {})); + EXPECT_EQ("MIXED", render("{{ 'MiXeD'.upper() }}", {}, {})); + EXPECT_EQ("", render("{{ ''.upper() }}", {}, {})); + + EXPECT_EQ("hello world", render("{{ 'HELLO WORLD'.lower() }}", {}, {})); + EXPECT_EQ("mixed", render("{{ 'MiXeD'.lower() }}", {}, {})); + EXPECT_EQ("", render("{{ ''.lower() }}", {}, {})); + EXPECT_EQ( "ok", render("{# Hey\nHo #}{#- Multiline...\nComments! -#}{{ 'ok' }}{# yo #}", {}, {})); From 8cd0d6d224fdc865a718bb5becb7d0d3bc4995ef Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 7 Aug 2025 22:57:47 +0100 Subject: [PATCH 54/59] Fix tests - except QwQ-32B (#77) * Port content handling of https://github.com/google/minja/pull/72 to python test logic * Pin dependency versions * add clangd/ to .gitignore * disable QwQ-32B test --- .gitignore | 3 ++- CMakeLists.txt | 6 +++--- scripts/fetch_templates_and_goldens.py | 12 ++++++------ tests/CMakeLists.txt | 2 +- tests/test-capabilities.cpp | 14 +++++++++++++- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 4049566..5f91bce 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist/ .DS_Store Testing/ .vscode/ -__pycache__/ \ No newline at end of file +__pycache__/ +clangd/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a3760c..95dabe7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,20 +55,20 @@ endif() include(FetchContent) # Fetch nlohmann/json -FetchContent_Declare(json URL https://github.com/nlohmann/json/archive/refs/heads/develop.zip) +FetchContent_Declare(json URL https://github.com/nlohmann/json/archive/refs/tags/v3.12.0.zip) FetchContent_MakeAvailable(json) target_link_libraries(minja INTERFACE nlohmann_json::nlohmann_json) if(MINJA_TEST_ENABLED) if (MINJA_FUZZTEST_ENABLED) # Fetch google/fuzztest (and indirectly, gtest) - FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/heads/main.zip) + FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/tags/2025-08-05.zip) FetchContent_MakeAvailable(fuzztest) message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") else() # Fetch gtest set(INSTALL_GTEST OFF) - FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/heads/main.zip) + FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip) FetchContent_MakeAvailable(googletest) endif() endif() diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 6950ffe..573f783 100644 --- a/scripts/fetch_templates_and_goldens.py +++ 
b/scripts/fetch_templates_and_goldens.py @@ -166,10 +166,14 @@ def __init__(self, template, env=None, filters=None, global_functions=None): }]) caps.supports_tools = "some_tool" in out + caps.requires_non_null_content = \ + (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ + and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) + def make_tool_calls_msg(tool_calls, content=None): return { "role": "assistant", - "content": content, + "content": "" if content is None and caps.requires_non_null_content else content, "tool_calls": tool_calls, } def make_tool_call(tool_name, arguments): @@ -198,10 +202,6 @@ def make_tool_call(tool_name, arguments): caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - caps.requires_non_null_content = \ - (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ - and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) - if caps.supports_tool_calls: dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) tc1 = make_tool_call("test_tool1", dummy_args) @@ -232,7 +232,7 @@ def make_tool_call(tool_name, arguments): args = {"arg1": "some_value"} tool_call_msg = { "role": "assistant", - "content": None, + "content": "" if caps.requires_non_null_content else None, "tool_calls": [ { "id": "call_1___", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 491eca0..14e789b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -285,7 +285,6 @@ set(MODEL_IDS Qwen/Qwen2.5-VL-3B-Instruct Qwen/Qwen2.5-VL-72B-Instruct Qwen/Qwen2.5-VL-7B-Instruct - Qwen/QwQ-32B Qwen/QwQ-32B-Preview rubenroy/Zurich-14B-GCv2-5m rubenroy/Zurich-7B-GCv2-5m @@ -323,6 +322,7 @@ set(MODEL_IDS Qwen/Qwen3-4B # Broken, TODO: + # Qwen/QwQ-32B # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 # Almawave/Velvet-14B # deepseek-ai/DeepSeek-R1 diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp index 53e3999..88968ce 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -75,6 +75,18 @@ TEST(CapabilitiesTest, Gemma7b) { EXPECT_FALSE(caps.requires_typed_content); } +TEST(CapabilitiesTest, QwQ32B) { + auto caps = get_caps("tests/Qwen-QwQ-32B.jinja"); + EXPECT_TRUE(caps.supports_system_role); + EXPECT_TRUE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tool_calls); + EXPECT_FALSE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_parallel_tool_calls); + EXPECT_TRUE(caps.requires_object_arguments); + // EXPECT_TRUE(caps.requires_non_null_content); + EXPECT_FALSE(caps.requires_typed_content); +} + #ifndef _WIN32 TEST(CapabilitiesTest, DeepSeekR1Distill) { @@ -141,7 +153,7 @@ TEST(CapabilitiesTest, MetaLlama3_3_70BInstruct) { TEST(CapabilitiesTest, MiniMaxAIText01) { auto caps = get_caps("tests/MiniMaxAI-MiniMax-Text-01.jinja"); EXPECT_TRUE(caps.supports_system_role); - EXPECT_FALSE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tools); EXPECT_FALSE(caps.supports_tool_calls); EXPECT_FALSE(caps.supports_tool_responses); EXPECT_FALSE(caps.supports_parallel_tool_calls); From b04424943e2d70ba94ff01a53177e0279276a26c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 7 Aug 2025 23:38:49 +0100 Subject: [PATCH 55/59] qwen3 coder support (in filter, keys method) 
(#78) --- include/minja/minja.hpp | 25 ++++++++++++++++++------- tests/CMakeLists.txt | 3 +++ tests/test-syntax.cpp | 4 ++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 9ae377d..8160a72 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1291,6 +1291,12 @@ class UnaryOpExpr : public Expression { } }; +static bool in(const Value & value, const Value & container) { + return (((container.is_array() || container.is_object()) && container.contains(value)) || + (value.is_string() && container.is_string() && + container.to_str().find(value.to_str()) != std::string::npos)); +}; + class BinaryOpExpr : public Expression { public: enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; @@ -1355,13 +1361,8 @@ class BinaryOpExpr : public Expression { case Op::Gt: return l > r; case Op::Le: return l <= r; case Op::Ge: return l >= r; - case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) || - (l.is_string() && r.is_string() && - r.to_str().find(l.to_str()) != std::string::npos)); - case Op::NotIn: - return !(((r.is_array() || r.is_object()) && r.contains(l)) || - (l.is_string() && r.is_string() && - r.to_str().find(l.to_str()) != std::string::npos)); + case Op::In: return in(l, r); + case Op::NotIn: return !in(l, r); default: break; } throw std::runtime_error("Unknown binary operator"); @@ -1500,6 +1501,13 @@ class MethodCallExpr : public Expression { } else if (method->get_name() == "pop") { vargs.expectArgs("pop method", {1, 1}, {0, 0}); return obj.pop(vargs.args[0]); + } else if (method->get_name() == "keys") { + vargs.expectArgs("keys method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value(key)); + } + return result; } else if (method->get_name() == "get") { vargs.expectArgs("get method", {1, 2}, {0, 0}); auto key = vargs.args[0]; @@ -2792,6 +2800,9 @@ inline std::shared_ptr Context::builtins() { if (!items.is_array()) throw std::runtime_error("object is not iterable"); return items; })); + globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr &, Value & args) -> Value { + return in(args.at("item"), args.at("items")); + })); globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { auto & items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not iterable"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 14e789b..3dca28b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -320,6 +320,9 @@ set(MODEL_IDS xwen-team/Xwen-72B-Chat xwen-team/Xwen-7B-Chat Qwen/Qwen3-4B + Qwen/Qwen3-235B-A22B-Instruct-2507 + Qwen/Qwen3-235B-A22B-Thinking-2507 + Qwen/Qwen3-Coder-30B-A3B-Instruct # Broken, TODO: # Qwen/QwQ-32B diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 55e9b0b..dfef761 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -217,6 +217,10 @@ TEST(SyntaxTest, SimpleCases) { render(R"({{ 'a' in 'abc' }},{{ 'd' in 'abc' }})", {}, {})); EXPECT_EQ("False,True", render(R"({{ 'a' not in 'abc' }},{{ 'd' not in 'abc' }})", {}, {})); + EXPECT_EQ("['a', 'a']", + render(R"({{ ['a', 'b', 'c', 'a'] | select('in', ['a']) | list }})", {}, {})); + EXPECT_EQ("['a', 'b'],[]", + render(R"({{ {'a': 1, 'b': 2}.keys() | list }},{{ {}.keys() | list }})", {}, {})); EXPECT_EQ( R"([{'a': 1}])", render(R"({{ [{"a": 1}, 
{"a": 2}, {}] | selectattr("a", "equalto", 1) | list }})", {}, {})); From bc6cec9b7171464cbff01f266169490d79d965c2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 7 Aug 2025 23:39:32 +0100 Subject: [PATCH 56/59] fix Qwen/QwQ-32B (#79) --- include/minja/chat-template.hpp | 11 +++++++++-- tests/CMakeLists.txt | 2 +- tests/test-capabilities.cpp | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 8f617db..fc9b7d4 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -162,8 +162,15 @@ class chat_template { }), false); caps_.supports_tools = contains(out, "some_tool"); - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + const auto render_with_content = [&](const json & content) { + const json assistant_msg {{"role", "assistant"}, {"content", content}}; + // Render two assistant messages as some templates like QwQ-32B are handling + // the content differently depending on whether it's the last message or not + // (to remove the tag in all but the last message). + return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false); + }; + auto out_empty = render_with_content(""); + auto out_null = render_with_content(json()); caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); json j_null; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3dca28b..0388a74 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -323,9 +323,9 @@ set(MODEL_IDS Qwen/Qwen3-235B-A22B-Instruct-2507 Qwen/Qwen3-235B-A22B-Thinking-2507 Qwen/Qwen3-Coder-30B-A3B-Instruct + Qwen/QwQ-32B # Broken, TODO: - # Qwen/QwQ-32B # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 # Almawave/Velvet-14B # deepseek-ai/DeepSeek-R1 diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp index 88968ce..1d43c62 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -80,7 +80,7 @@ TEST(CapabilitiesTest, QwQ32B) { EXPECT_TRUE(caps.supports_system_role); EXPECT_TRUE(caps.supports_tools); EXPECT_TRUE(caps.supports_tool_calls); - EXPECT_FALSE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_tool_responses); EXPECT_TRUE(caps.supports_parallel_tool_calls); EXPECT_TRUE(caps.requires_object_arguments); // EXPECT_TRUE(caps.requires_non_null_content); From 5be6f88a648570b26341bb008e686f3b64c2f4ac Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 8 Aug 2025 02:09:48 +0100 Subject: [PATCH 57/59] Fix qwen3 coder capabilities detection (#80) * fix items filter * test qwen3coder capabilities * fix detection of tool calls and obj args requirement for qwen3 coder --- include/minja/chat-template.hpp | 4 +- include/minja/minja.hpp | 14 +-- scripts/fetch_templates_and_goldens.py | 4 +- tests/test-capabilities.cpp | 12 ++ tests/test-syntax.cpp | 7 +- .../Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja | 117 ++++++++++++++++++ 6 files changed, 142 insertions(+), 16 deletions(-) create mode 100644 tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index fc9b7d4..d31fb90 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -198,12 +198,12 @@ class 
chat_template { dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), }), {}, false); - auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_str_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), }), {}, false); - auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_obj_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 8160a72..f04073c 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2664,15 +2664,11 @@ inline std::shared_ptr Context::builtins() { auto items = Value::array(); if (args.contains("object")) { auto & obj = args.at("object"); - if (obj.is_string()) { - auto json_obj = json::parse(obj.get()); - for (const auto & kv : json_obj.items()) { - items.push_back(Value::array({kv.key(), kv.value()})); - } - } else if (!obj.is_null()) { - for (auto & key : obj.keys()) { - items.push_back(Value::array({key, obj.at(key)})); - } + if (!obj.is_object()) { + throw std::runtime_error("Can only get item pairs from a mapping"); + } + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); } } return items; diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 573f783..acaf969 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -192,12 +192,12 @@ def make_tool_call(tool_name, arguments): dummy_user_msg, make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), ]) - tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out + tool_call_renders_str_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) - tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out + tool_call_renders_obj_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp index 1d43c62..458f9b9 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -87,6 +87,18 @@ TEST(CapabilitiesTest, QwQ32B) { EXPECT_FALSE(caps.requires_typed_content); } +TEST(CapabilitiesTest, Qwen3Coder) { + auto caps = get_caps("tests/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja"); + EXPECT_TRUE(caps.supports_system_role); + EXPECT_TRUE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tool_calls); + EXPECT_TRUE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_parallel_tool_calls); + EXPECT_TRUE(caps.requires_object_arguments); + // 
EXPECT_TRUE(caps.requires_non_null_content); + EXPECT_FALSE(caps.requires_typed_content); +} + #ifndef _WIN32 TEST(CapabilitiesTest, DeepSeekR1Distill) { diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index dfef761..b4bf638 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -430,9 +430,6 @@ TEST(SyntaxTest, SimpleCases) { {{- foo() }} {{ foo() -}})", {}, {})); if (!getenv("USE_JINJA2")) { - EXPECT_EQ( - "[]", - render(R"({{ None | items | list | tojson }})", {}, {})); EXPECT_EQ( "Foo", render(R"({% generation %}Foo{% endgeneration %})", {}, {})); @@ -561,6 +558,10 @@ TEST(SyntaxTest, SimpleCases) { if (!getenv("USE_JINJA2")) { // TODO: capture stderr from jinja2 and test these. + EXPECT_THAT([]() { render("{{ '' | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + EXPECT_THAT([]() { render("{{ [] | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + EXPECT_THAT([]() { render("{{ None | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + EXPECT_THAT([]() { render("{% break %}", {}, {}); }, ThrowsWithSubstr("break outside of a loop")); EXPECT_THAT([]() { render("{% continue %}", {}, {}); }, ThrowsWithSubstr("continue outside of a loop")); diff --git a/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja b/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja new file mode 100644 index 0000000..e539012 --- /dev/null +++ b/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- 
elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} From 3e4c61c616eda133cfb1e440fc7a14bf1729bbee Mon Sep 17 00:00:00 2001 From: Yaroslav Tarkan Date: Tue, 9 Sep 2025 20:46:24 +0300 Subject: [PATCH 58/59] Support `call`/`endcall` blocks (#81) * Add call/endcall support * Add syntax tests for call blocks * Add call blocks to supported features in readme * Add openbmb/MiniCPM3-4B to test models * Remove non-existent model * Add tests for unterminated call and macro --- README.md | 3 +- include/minja/minja.hpp | 86 ++++++++++++++++++++++++++++++++++++----- tests/CMakeLists.txt | 2 +- tests/test-syntax.cpp | 56 +++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index aa3579c..a33bc36 100644 --- a/README.md +++ b/README.md @@ -104,10 +104,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji - Full expression syntax - Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}` - `if` / `elif` / `else` / `endif` -- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring +- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring) - `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39)) - `set` w/ namespaces & destructuring - `macro` / `endmacro` +- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block) - `filter` / `endfilter` - Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim` diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index f04073c..5ed0556 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -706,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall }; static std::string typeToString(Type t) { switch (t) { @@ -729,6 +729,8 @@ class TemplateToken { case Type::EndGeneration: return "endgeneration"; case Type::Break: return "break"; case Type::Continue: return "continue"; + case Type::Call: return "call"; + case Type::EndCall: return "endcall"; } return "Unknown"; } @@ -846,6 +848,17 @@ struct LoopControlTemplateToken : public TemplateToken { LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} }; +struct CallTemplateToken : public TemplateToken { + std::shared_ptr expr; + CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) + : TemplateToken(Type::Call, loc, pre, post), 
expr(std::move(e)) {} +}; + +struct EndCallTemplateToken : public TemplateToken { + EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) + : TemplateToken(Type::EndCall, loc, pre, post) {} +}; + class TemplateNode { Location location_; protected: @@ -1050,31 +1063,36 @@ class MacroNode : public TemplateNode { void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { if (!name) throw std::runtime_error("MacroNode.name is null"); if (!body) throw std::runtime_error("MacroNode.body is null"); - auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { - auto call_context = macro_context; + auto callable = Value::callable([this, macro_context](const std::shared_ptr & call_context, ArgumentsValue & args) { + auto execution_context = Context::make(Value::object(), macro_context); + + if (call_context->contains("caller")) { + execution_context->set("caller", call_context->get("caller")); + } + std::vector param_set(params.size(), false); for (size_t i = 0, n = args.args.size(); i < n; i++) { auto & arg = args.args[i]; if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); param_set[i] = true; auto & param_name = params[i].first; - call_context->set(param_name, arg); + execution_context->set(param_name, arg); } for (auto & [arg_name, value] : args.kwargs) { auto it = named_param_positions.find(arg_name); if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - call_context->set(arg_name, value); + execution_context->set(arg_name, value); param_set[it->second] = true; } // Set default values for parameters that were not passed for (size_t i = 0, n = params.size(); i < n; i++) { if (!param_set[i] && params[i].second != nullptr) { - auto val = params[i].second->evaluate(context); - call_context->set(params[i].first, val); + auto val = params[i].second->evaluate(call_context); + execution_context->set(params[i].first, val); } } - return body->render(call_context); + return body->render(execution_context); }); macro_context->set(name->get_name(), callable); } @@ -1611,6 +1629,40 @@ class CallExpr : public Expression { } }; +class CallNode : public TemplateNode { + std::shared_ptr expr; + std::shared_ptr body; + +public: + CallNode(const Location & loc, std::shared_ptr && e, std::shared_ptr && b) + : TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("CallNode.expr is null"); + if (!body) throw std::runtime_error("CallNode.body is null"); + + auto caller = Value::callable([this, context](const std::shared_ptr &, ArgumentsValue &) -> Value { + return Value(body->render(context)); + }); + + context->set("caller", caller); + + auto call_expr = dynamic_cast(expr.get()); + if (!call_expr) { + throw std::runtime_error("Invalid call block syntax - expected function call"); + } + + Value function = call_expr->object->evaluate(context); + if (!function.is_callable()) { + throw std::runtime_error("Call target must be callable: " + function.dump()); + } + ArgumentsValue args = call_expr->args.evaluate(context); + + Value result = function.call(context, args); + out << result.to_str(); + } +}; + class FilterExpr : public Expression { std::vector> parts; public: @@ -2320,7 +2372,7 @@ class Parser { static std::regex 
comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); static std::regex block_close_regex(R"(\s*([-~])?%\})"); @@ -2443,6 +2495,15 @@ class Parser { } else if (keyword == "endmacro") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "call") { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (keyword == "endcall") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "filter") { auto filter = parseExpression(); if (!filter) throw std::runtime_error("Expected expression in filter block"); @@ -2575,6 +2636,12 @@ class Parser { throw unterminated(**start); } children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto call_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(call_token->expr), std::move(body))); } else if (auto filter_token = dynamic_cast(token.get())) { auto body = parseTemplate(begin, it, end); if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { @@ -2588,6 +2655,7 @@ class Parser { } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) + || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0388a74..db82c2d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -226,6 +226,7 @@ set(MODEL_IDS OnlyCheeini/greesychat-turbo onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX open-thoughts/OpenThinker-7B + openbmb/MiniCPM3-4B openchat/openchat-3.5-0106 Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2 OrionStarAI/Orion-14B-Chat @@ -261,7 +262,6 @@ set(MODEL_IDS prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M prithivMLmods/QwQ-Math-IO-500M prithivMLmods/Triangulum-v2-10B - qingy2024/Falcon3-2x10B-MoE-Instruct Qwen/QVQ-72B-Preview Qwen/Qwen1.5-7B-Chat Qwen/Qwen2-7B-Instruct diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index b4bf638..36bdaa3 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -429,6 +429,54 @@ TEST(SyntaxTest, SimpleCases) { {%- endmacro -%} {{- foo() }} {{ foo() -}})", {}, {})); + EXPECT_EQ( + "x,x", + render(R"( + {%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%} + {%- call test() -%}x{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "Outer[Inner(X)]", + render(R"( + {%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%} + 
{%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%} + {%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "
  • A
  • B
", + render(R"( + {%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%} + {%- set items = ["a", "b"] -%} + {%- call test("
    ", "
") -%} + {%- for item in items -%} +
  • {{ item | upper }}
  • + {%- endfor -%} + {%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "\\n\\nclass A:\\n b: 1\\n c: 2\\n", + render(R"( + {%- macro recursive(obj) -%} + {%- set ns = namespace(content = caller()) -%} + {%- for key, value in obj.items() %} + {%- if value is mapping %} + {%- call recursive(value) -%} + {{ '\\n\\nclass ' + key.title() + ':\\n' }} + {%- endcall -%} + {%- else -%} + {%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%} + {%- endif -%} + {%- endfor -%} + {{ ns.content }} + {%- endmacro -%} + + {%- call recursive({"a": {"b": "1", "c": "2"}}) -%} + {%- endcall -%} + )", {}, {})); + if (!getenv("USE_JINJA2")) { EXPECT_EQ( "Foo", @@ -576,6 +624,8 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif")); EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor")); EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter")); + EXPECT_THAT([]() { render("{% endmacro %}", {}, {}); }, ThrowsWithSubstr("Unexpected endmacro")); + EXPECT_THAT([]() { render("{% endcall %}", {}, {}); }, ThrowsWithSubstr("Unexpected endcall")); EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for")); @@ -584,6 +634,12 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter")); EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag")); + EXPECT_THAT([]() { render("{% macro test() %}", {}, {}); }, ThrowsWithSubstr("Unterminated macro")); + EXPECT_THAT([]() { render("{% call test %}", {}, {}); }, ThrowsWithSubstr("Unterminated call")); + + EXPECT_THAT([]() { + render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {}); + }, ThrowsWithSubstr("Invalid call block syntax - expected function call")); } EXPECT_EQ( From 021c2293c187789ef13d56c6cfd89c9b134fd80f Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 22 Sep 2025 18:33:38 +0100 Subject: [PATCH 59/59] Used by Docker Model Runner (#85) For jinja templates Signed-off-by: Eric Curtin --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a33bc36..5981079 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ _**This is not an official Google product**_ -Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016), [Jan](https://jan.ai/) (through [cortex.cpp](https://github.com/menloresearch/cortex.cpp/pull/1814)) and [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433)). +Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016), [Jan](https://jan.ai/) (through [cortex.cpp](https://github.com/menloresearch/cortex.cpp/pull/1814)), [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433) and [Docker Model Runner](https://github.com/docker/model-runner)). 
It is **not general purpose**: it includes just what’s needed for actual chat templates (very limited set of filters, tests and language features). Users with different needs should look at third-party alternatives such as [Jinja2Cpp](https://github.com/jinja2cpp/Jinja2Cpp), [Jinja2CppLight](https://github.com/hughperkins/Jinja2CppLight), or [inja](https://github.com/pantor/inja) (none of which we endorse).
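
To see the `call` / `endcall` blocks from PATCH 58/59 end to end, here is a minimal, hypothetical harness. It is a sketch only: the `minja::Parser::parse`, `minja::Context::make`, `minja::Value` and `render` calls mirror the types that appear in the diffs above, but the exact signatures, the header path and the default-constructed options argument are assumptions rather than verbatim repository code.

    // Hypothetical usage sketch -- API names are taken from the diffs above,
    // exact signatures are assumed, not copied from the repository.
    #include <minja/minja.hpp>

    #include <iostream>
    #include <string>

    using json = nlohmann::json;

    int main() {
        // A macro that wraps whatever the {% call %} body produces; inside the
        // macro, caller() re-renders the body of the call block.
        std::string source = R"(
            {%- macro wrap(mark) -%}{{ mark }}{{ caller() }}{{ mark }}{%- endmacro -%}
            {%- call wrap("**") -%}{{ name }}{%- endcall -%}
        )";

        auto tmpl = minja::Parser::parse(source, /* options= */ {});
        auto ctx  = minja::Context::make(minja::Value(json{{"name", "minja"}}));

        std::cout << tmpl->render(ctx) << "\n";  // expected output: **minja**
        return 0;
    }

In this sketch the macro receives `"**"` as a positional argument, and the body of the call block is re-rendered each time `caller()` is evaluated, which is the behaviour exercised by the `CallNode` / `MacroNode` changes and the `"x,x"` test in PATCH 58.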