diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d73d9b9..5f7fbdb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -52,7 +52,8 @@ jobs: - name: Clone uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 1 + single-branch: true - name: ccache uses: hendrikmuhs/ccache-action@v1.2.11 diff --git a/.gitignore b/.gitignore index 4049566..5f91bce 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist/ .DS_Store Testing/ .vscode/ -__pycache__/ \ No newline at end of file +__pycache__/ +clangd/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index bd11142..95dabe7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,8 +11,26 @@ cmake_policy(SET CMP0135 NEW) # https://cmake.org/cmake/help/latest/policy/CMP01 project(minja VERSION 1.0.0 LANGUAGES CXX) +add_library(minja INTERFACE) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +# Test if clang-tidy is available +find_program(CLANG_TIDY_EXE NAMES "clang-tidy") +if (CLANG_TIDY_EXE) + message(STATUS "clang-tidy found: ${CLANG_TIDY_EXE}") + set(CMAKE_CXX_CLANG_TIDY + clang-tidy; + -header-filter=include/minja/.*; + # https://clang.llvm.org/extra/clang-tidy/checks/list.html + # TODO: enable more / disable less checks: google-*,misc-*,modernize-*,performance-* + -checks=-*,clang-analyzer-*,clang-diagnostic-*,cppcoreguidelines-*,bugprone-*,-bugprone-suspicious-include,-bugprone-assignment-in-if-condition,-bugprone-narrowing-conversions,-bugprone-easily-swappable-parameters,-bugprone-inc-dec-in-conditions,-bugprone-exception-escape,-clang-analyzer-cplusplus.StringChecker; + -warnings-as-errors=*; + ) +else() + message(STATUS "clang-tidy not found") +endif() + if (MSVC) set(MINJA_FUZZTEST_ENABLED_DEFAULT OFF) set(MINJA_USE_VENV_DEFAULT OFF) @@ -20,6 +38,8 @@ else() set(MINJA_FUZZTEST_ENABLED_DEFAULT ON) set(MINJA_USE_VENV_DEFAULT ON) endif() +option(MINJA_TEST_ENABLED "minja: Build with tests (python interpreter required)" ON) 
+option(MINJA_EXAMPLE_ENABLED "minja: Build with example" ON) option(MINJA_FUZZTEST_ENABLED "minja: fuzztests enabled" MINJA_FUZZTEST_ENABLED_DEFAULT) option(MINJA_FUZZTEST_FUZZING_MODE "minja: run fuzztests (if enabled) in fuzzing mode" OFF) option(MINJA_USE_VENV "minja: use Python venv for build" MINJA_USE_VENV_DEFAULT) @@ -35,17 +55,22 @@ endif() include(FetchContent) # Fetch nlohmann/json -FetchContent_Declare(json URL https://github.com/nlohmann/json/archive/refs/heads/develop.zip) +FetchContent_Declare(json URL https://github.com/nlohmann/json/archive/refs/tags/v3.12.0.zip) FetchContent_MakeAvailable(json) - -if (MINJA_FUZZTEST_ENABLED) - # Fetch google/fuzztest (and indirectly, gtest) - FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/heads/main.zip) - FetchContent_MakeAvailable(fuzztest) -else() - # Fetch gtest - FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/heads/main.zip) - FetchContent_MakeAvailable(googletest) +target_link_libraries(minja INTERFACE nlohmann_json::nlohmann_json) + +if(MINJA_TEST_ENABLED) + if (MINJA_FUZZTEST_ENABLED) + # Fetch google/fuzztest (and indirectly, gtest) + FetchContent_Declare(fuzztest URL https://github.com/google/fuzztest/archive/refs/tags/2025-08-05.zip) + FetchContent_MakeAvailable(fuzztest) + message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") + else() + # Fetch gtest + set(INSTALL_GTEST OFF) + FetchContent_Declare(googletest URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip) + FetchContent_MakeAvailable(googletest) + endif() endif() # Use ccache if installed @@ -61,25 +86,27 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() -set(Python_FIND_STRATEGY LOCATION CACHE STRING "Python find strategy" FORCE) -find_package(Python COMPONENTS Interpreter REQUIRED) -if(MINJA_USE_VENV) - # Create a python venv 
w/ the required dependencies - set(VENV_DIR "${CMAKE_BINARY_DIR}/venv") - if(WIN32) - set(VENV_PYTHON "${VENV_DIR}/Scripts/python.exe") - else() - set(VENV_PYTHON "${VENV_DIR}/bin/python") +if(MINJA_TEST_ENABLED) + set(Python_FIND_STRATEGY LOCATION CACHE STRING "Python find strategy" FORCE) + find_package(Python COMPONENTS Interpreter REQUIRED) + if(MINJA_USE_VENV) + # Create a python venv w/ the required dependencies + set(VENV_DIR "${CMAKE_BINARY_DIR}/venv") + if(WIN32) + set(VENV_PYTHON "${VENV_DIR}/Scripts/python.exe") + else() + set(VENV_PYTHON "${VENV_DIR}/bin/python") + endif() + execute_process( + COMMAND ${Python_EXECUTABLE} -m venv "${VENV_DIR}" + COMMAND_ERROR_IS_FATAL ANY) + execute_process( + COMMAND ${VENV_PYTHON} -m pip install -r "${CMAKE_SOURCE_DIR}/requirements.txt" + COMMAND_ERROR_IS_FATAL ANY) + set(Python_EXECUTABLE "${VENV_PYTHON}" CACHE FILEPATH "Path to Python executable in venv" FORCE) endif() - execute_process( - COMMAND ${Python_EXECUTABLE} -m venv "${VENV_DIR}" - COMMAND_ERROR_IS_FATAL ANY) - execute_process( - COMMAND ${VENV_PYTHON} -m pip install -r "${CMAKE_SOURCE_DIR}/requirements.txt" - COMMAND_ERROR_IS_FATAL ANY) - set(Python_EXECUTABLE "${VENV_PYTHON}" CACHE FILEPATH "Path to Python executable in venv" FORCE) + message(STATUS "Python executable: ${Python_EXECUTABLE}") endif() -message(STATUS "Python executable: ${Python_EXECUTABLE}") find_program(CPPCHECK cppcheck) if(CPPCHECK) @@ -87,15 +114,29 @@ if(CPPCHECK) message(STATUS "cppcheck found: ${CPPCHECK}") endif() -message(STATUS "${fuzztest_BINARY_DIR}: ${${fuzztest_BINARY_DIR}}") -include_directories( - include/minja - ${json_SOURCE_DIR}/include - ${json_SOURCE_DIR}/include/nlohmann +include(GNUInstallDirs) +target_include_directories(minja INTERFACE + $ + $ ) -add_subdirectory(examples) +install(FILES + ${PROJECT_SOURCE_DIR}/include/minja/minja.hpp + ${PROJECT_SOURCE_DIR}/include/minja/chat-template.hpp + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/minja +) +install( + TARGETS 
minja + EXPORT "${TARGETS_EXPORT_NAME}" + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/minja # for downstream projects +) -enable_testing() -include(GoogleTest) -add_subdirectory(tests) +if(MINJA_EXAMPLE_ENABLED) + add_subdirectory(examples) +endif() + +if(MINJA_TEST_ENABLED) + enable_testing() + include(GoogleTest) + add_subdirectory(tests) +endif() diff --git a/README.md b/README.md index d6ffc7c..5981079 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,16 @@ _**This is not an official Google product**_ -Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (such as [llama.cpp](https://github.com/ggerganov/llama.cpp) or [gemma.cpp](https://github.com/google/gemma.cpp)). +Minja is a minimalistic reimplementation of the [Jinja](https://github.com/pallets/jinja/) templating engine to integrate in/with C++ LLM projects (it's used in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/11016), [Jan](https://jan.ai/) (through [cortex.cpp](https://github.com/menloresearch/cortex.cpp/pull/1814)), [GPT4All](https://github.com/nomic-ai/gpt4all/pull/3433) and [Docker Model Runner](https://github.com/docker/model-runner)). It is **not general purpose**: it includes just what’s needed for actual chat templates (very limited set of filters, tests and language features). Users with different needs should look at third-party alternatives such as [Jinja2Cpp](https://github.com/jinja2cpp/Jinja2Cpp), [Jinja2CppLight](https://github.com/hughperkins/Jinja2CppLight), or [inja](https://github.com/pantor/inja) (none of which we endorse). > [!WARNING] > TL;DR: use of Minja is *at your own risk*, and the risks are plenty! See [Security & Privacy](#security--privacy) section below. 
+> [!IMPORTANT] +> [@ochafik](https://github.com/ochafik) has left Google, watch out for https://github.com/ochafik/minja + [![CI](https://github.com/google/minja/actions/workflows/build.yml/badge.svg)](https://github.com/google/minja/actions/workflows/build.yml) ## Design goals: @@ -31,9 +34,17 @@ It is **not general purpose**: it includes just what’s needed for actual chat ## Usage: -This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++11 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json)'s `json.hpp` in your include path. +This library is header-only: just copy the header(s) you need, make sure to use a compiler that handles C++17 and you're done. Oh, and get [nlohmann::json](https://github.com/nlohmann/json) in your include path. + +If your project is based on [cmake](https://cmake.org/), you can simply import it by using `FetchContent`. +``` +FetchContent_Declare(minja GIT_REPOSITORY "https://github.com/google/minja") +FetchContent_MakeAvailable(minja) + +target_link_libraries(<your_target> PRIVATE minja) +``` -See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.h](./include/minja/chat-template.hpp) (experimental). +See API in [minja/minja.hpp](./include/minja/minja.hpp) and [minja/chat-template.hpp](./include/minja/chat-template.hpp) (experimental). 
For raw Jinja templating (see [examples/raw.cpp](./examples/raw.cpp)): @@ -93,11 +104,13 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji - Full expression syntax - Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}` - `if` / `elif` / `else` / `endif` -- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring +- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring) +- `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39)) - `set` w/ namespaces & destructuring - `macro` / `endmacro` +- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block) - `filter` / `endfilter` -- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject`, `tojson`, `trim` +- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim` Main limitations (non-exhaustive list): @@ -110,8 +123,10 @@ Main limitations (non-exhaustive list): ## Roadmap / TODOs -- [x] Fix known issues w/ CRLF on Windows -- [ ] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 +- [ ] Fix known line difference issues on Windows +- [ ] Document the various capabilities detectors + backfill strategies used +- [ ] Propose integration w/ https://github.com/google/gemma.cpp +- [x] Integrate to llama.cpp: https://github.com/ggerganov/llama.cpp/pull/11016 + https://github.com/ggerganov/llama.cpp/pull/9639 - Improve fuzzing coverage: - use 
thirdparty jinja grammar to guide exploration of inputs (or implement prettification of internal ASTs and use them to generate arbitrary values) - fuzz each filter / test @@ -154,7 +169,7 @@ Main limitations (non-exhaustive list): huggingface-cli login ``` -- Build & run tests (shorthand: `./scripts/run_tests.sh`): +- Build & run tests (shorthand: `./scripts/tests.sh`): ```bash rm -fR build && \ @@ -163,6 +178,8 @@ Main limitations (non-exhaustive list): ctest --test-dir build -j --output-on-failure ``` +- Bonus: install `clang-tidy` before building (on MacOS: `brew install llvm ; sudo ln -s "$(brew --prefix llvm)/bin/clang-tidy" "/usr/local/bin/clang-tidy"`) + - Fuzzing tests - Note: `fuzztest` **[doesn't work](https://github.com/google/fuzztest/issues/179)** natively on Windows or MacOS. @@ -192,7 +209,7 @@ Main limitations (non-exhaustive list): - Build in [fuzzing mode](https://github.com/google/fuzztest/blob/main/doc/quickstart-cmake.md#fuzzing-mode) & run all fuzzing tests (optionally, set a higher `TIMEOUT` as env var): ```bash - ./scripts/run_fuzzing_mode.sh + ./scripts/fuzzing_tests.sh ``` - If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new): diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 750bdea..16e8fb2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,6 +11,9 @@ foreach(example ) add_executable(${example} ${example}.cpp) target_compile_features(${example} PUBLIC cxx_std_17) - target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json) + target_link_libraries(${example} PRIVATE minja) + if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(${example} PUBLIC _CRT_SECURE_NO_WARNINGS) + endif() endforeach() diff --git a/examples/chat-template.cpp b/examples/chat-template.cpp index 8161b2b..792a157 100644 --- a/examples/chat-template.cpp +++ 
b/examples/chat-template.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include +#include #include using json = nlohmann::ordered_json; @@ -19,14 +19,16 @@ int main() { /* bos_token= */ "<|start|>", /* eos_token= */ "<|end|>" ); - std::cout << tmpl.apply( - json::parse(R"([ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"} - ])"), - json::parse(R"([ - {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} - ])"), - /* add_generation_prompt= */ true, - /* extra_context= */ {}) << std::endl; + + minja::chat_template_inputs inputs; + inputs.messages = json::parse(R"([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"} + ])"); + inputs.add_generation_prompt = true; + inputs.tools = json::parse(R"([ + {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} + ])"); + + std::cout << tmpl.apply(inputs) << std::endl; } diff --git a/examples/raw.cpp b/examples/raw.cpp index d36a03a..c129449 100644 --- a/examples/raw.cpp +++ b/examples/raw.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include +#include #include using json = nlohmann::ordered_json; diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 8e05652..d31fb90 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -9,10 +9,21 @@ #pragma once #include "minja.hpp" -#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include +#include + using json = nlohmann::ordered_json; namespace minja { @@ -36,12 +47,24 @@ struct chat_template_caps { struct chat_template_inputs { nlohmann::ordered_json messages; nlohmann::ordered_json tools; - bool add_generation_prompt; + bool add_generation_prompt = true; nlohmann::ordered_json extra_context; - // Epoch time in milliseconds. 
- uint64_t now; - // Timezone offset in minutes. - int64_t timezone_offset; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); +}; + +struct chat_template_options { + bool apply_polyfills = true; + bool use_bos_token = true; + bool use_eos_token = true; + bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; }; class chat_template { @@ -52,6 +75,7 @@ class chat_template { std::string bos_token_; std::string eos_token_; std::shared_ptr template_root_; + std::string tool_call_example_; std::string try_raw_render( const nlohmann::ordered_json & messages, @@ -60,15 +84,18 @@ class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - chat_template_inputs inputs { - messages, - tools, - add_generation_prompt, - extra_context, - /* now= */ 0, - /* timezone_offset= */ 0, - }; - auto prompt = apply(inputs, /* adjust_inputs= */ false); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + // Use fixed date for tests + inputs.now = std::chrono::system_clock::from_time_t(0); + + chat_template_options opts; + opts.apply_polyfills = false; + + auto prompt = apply(inputs, opts); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -135,10 +162,22 @@ class chat_template { }), false); caps_.supports_tools = contains(out, "some_tool"); + const auto render_with_content = [&](const json & content) { + const json assistant_msg {{"role", "assistant"}, {"content", content}}; + // Render two assistant messages as some templates like QwQ-32B are handling + // the content differently 
depending on whether it's the last message or not + // (to remove the tag in all but the last message). + return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false); + }; + auto out_empty = render_with_content(""); + auto out_null = render_with_content(json()); + caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); + + json j_null; auto make_tool_calls_msg = [&](const json & tool_calls) { return json { {"role", "assistant"}, - {"content", nullptr}, + {"content", caps_.requires_non_null_content? "" : j_null}, {"tool_calls", tool_calls}, }; }; @@ -159,18 +198,15 @@ class chat_template { dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), }), {}, false); - auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_str_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), }), {}, false); - auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_obj_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); - caps_.requires_non_null_content = contains(out_empty, user_needle) && 
!contains(out_null, user_needle); if (caps_.supports_tool_calls) { auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); @@ -195,6 +231,72 @@ class chat_template { caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); } + + try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json args { + {"arg1", "some_value"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", caps_.requires_non_null_content ? "" : j_null}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, + }}, + }, + })}, + }; + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } + auto eos_pos_last = full.rfind(eos_token_); + if (eos_pos_last == prefix.size() - eos_token_.size() || + (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { + full = full.substr(0, eos_pos_last); + } + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; + } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. 
+ continue; + } + common_prefix_length = i + 1; + } + auto example = full.substr(common_prefix_length); + if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } else { + tool_call_example_ = example; + } + } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } } const std::string & source() const { return source_; } @@ -202,29 +304,72 @@ class chat_template { const std::string & eos_token() const { return eos_token_; } const chat_template_caps & original_caps() const { return caps_; } + // Deprecated, please use the form with chat_template_inputs and chat_template_options + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + std::string apply( const chat_template_inputs & inputs, - // const nlohmann::ordered_json & messages, - // const nlohmann::ordered_json & tools, - // bool add_generation_prompt, - // const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool adjust_inputs = true) const + const chat_template_options & opts = chat_template_options()) const { json actual_messages; - auto needs_adjustments = adjust_inputs && (false - || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || 
caps_.requires_object_arguments - || caps_.requires_typed_content + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + for (const auto & message : inputs.messages) { + if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { + has_tool_calls = true; + } + if (message.contains("role") && message["role"] == "tool") { + has_tool_responses = true; + } + if (message.contains("content") && message["content"].is_string()) { + has_string_content = true; + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + + auto needs_polyfills = opts.apply_polyfills && (false + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content ); - if (needs_adjustments) { + + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { - if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { actual_messages.push_back({ {"role", msg.at("role")}, {"content", {{ @@ -247,17 
+392,25 @@ class chat_template { pending_system.clear(); } }; - auto needs_tools_in_system = !inputs.tools.is_null() && inputs.tools.size() > 0 && !caps_.supports_tools; - for (const auto & message_ : needs_tools_in_system ? add_system(inputs.messages, "Available tools: " + inputs.tools.dump(2)) : inputs.messages) { + json adjusted_messages; + if (polyfill_tools) { + adjusted_messages = add_system(inputs.messages, + "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); + } else { + adjusted_messages = inputs.messages; + } + + for (const auto & message_ : adjusted_messages) { auto message = message_; - if (!message.contains("role") || !message.contains("content")) { - throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) { + throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump()); } std::string role = message.at("role"); if (message.contains("tool_calls")) { - if (caps_.requires_object_arguments || !caps_.supports_tool_calls) { + if (polyfill_object_arguments || polyfill_tool_calls) { for (auto & tool_call : message.at("tool_calls")) { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); @@ -272,8 +425,7 @@ class chat_template { } } } - if (!caps_.supports_tool_calls) { - auto content = message.at("content"); + if (polyfill_tool_calls) { auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { if (tool_call.at("type") != "function") { @@ -292,23 +444,25 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && content != "") { - obj["content"] = 
content; + if (message.contains("content")) { + auto content = message.at("content"); + if (!content.is_null() && !content.empty()) { + obj["content"] = content; + } } message["content"] = obj.dump(2); message.erase("tool_calls"); } } - if (!caps_.supports_tool_responses && role == "tool") { + if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { - {"tool_response", { - {"content", message.at("content")}, - }}, + {"tool_response", json::object()}, }; if (message.contains("name")) { - obj["tool_response"]["name"] = message.at("name"); + obj["tool_response"]["tool"] = message.at("name"); } + obj["tool_response"]["content"] = message.at("content"); if (message.contains("tool_call_id")) { obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); } @@ -316,7 +470,7 @@ class chat_template { message.erase("name"); } - if (!message["content"].is_null() && !caps_.supports_system_role) { + if (!message["content"].is_null() && polyfill_system_role) { std::string content = message.at("content"); if (role == "system") { if (!pending_system.empty()) pending_system += "\n"; @@ -335,9 +489,7 @@ class chat_template { } add_message(message); } - if (!caps_.supports_system_role) { - flush_sys(); - } + flush_sys(); } else { actual_messages = inputs.messages; } @@ -345,23 +497,28 @@ class chat_template { auto context = minja::Context::make(json({ {"messages", actual_messages}, {"add_generation_prompt", inputs.add_generation_prompt}, - {"bos_token", bos_token_}, - {"eos_token", eos_token_}, - // {"strftime_now", Value::callable([=](const std::shared_ptr & context, minja::ArgumentsValue & args) { - // args.expectArgs("strftime_now", {1, 1}, {0, 0}); - // auto format = args.args[0].get(); - // return Value(std::to_string(inputs.now)); - // })}, })); - + context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); + context->set("eos_token", opts.use_eos_token ? 
eos_token_ : ""); + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); + + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); + } if (!inputs.tools.is_null()) { - auto tools_val = minja::Value(inputs.tools); - context->set("tools", tools_val); + context->set("tools", minja::Value(inputs.tools)); } if (!inputs.extra_context.is_null()) { for (auto & kv : inputs.extra_context.items()) { - minja::Value val(kv.value()); - context->set(kv.key(), val); + context->set(kv.key(), minja::Value(kv.value())); } } @@ -374,11 +531,11 @@ class chat_template { static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { json messages_with_system = messages; - if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, + {"content", existing_system + "\n\n" + system_prompt}, }; } else { messages_with_system.insert(messages_with_system.begin(), json { diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index f0e80fd..5ed0556 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -8,15 +8,28 @@ // SPDX-License-Identifier: MIT #pragma once +#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include +#include +#include #include -#include +#include #include +#include +#include +#include #include -#include 
+#include +#include + +#include using json = nlohmann::ordered_json; @@ -221,7 +234,7 @@ class Value : public std::enable_shared_from_this { } } else if (is_object()) { if (!index.is_hashable()) - throw std::runtime_error("Unashable type: " + index.dump()); + throw std::runtime_error("Unhashable type: " + index.dump()); auto it = object_->find(index.primitive_); if (it == object_->end()) throw std::runtime_error("Key not found: " + index.dump()); @@ -240,7 +253,7 @@ class Value : public std::enable_shared_from_this { auto index = key.get(); return array_->at(index < 0 ? array_->size() + index : index); } else if (object_) { - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); auto it = object_->find(key.primitive_); if (it == object_->end()) return Value(); return it->second; @@ -249,7 +262,7 @@ class Value : public std::enable_shared_from_this { } void set(const Value& key, const Value& value) { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); (*object_)[key.primitive_] = value; } Value call(const std::shared_ptr & context, ArgumentsValue & args) const { @@ -386,7 +399,7 @@ class Value : public std::enable_shared_from_this { } return false; } else if (object_) { - if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump()); return object_->find(value.primitive_) != object_->end(); } else { throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); @@ -404,7 +417,7 @@ class Value : public std::enable_shared_from_this { return const_cast(this)->at(index); } Value& at(const Value & index) { - if (!index.is_hashable()) throw 
std::runtime_error("Unashable type: " + dump()); + if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); if (is_array()) return array_->at(index.get()); if (is_object()) return object_->at(index.primitive_); throw std::runtime_error("Value is not an array or object: " + dump()); @@ -664,8 +677,8 @@ class Expression { class VariableExpr : public Expression { std::string name; public: - VariableExpr(const Location & location, const std::string& n) - : Expression(location), name(n) {} + VariableExpr(const Location & loc, const std::string& n) + : Expression(loc), name(n) {} std::string get_name() const { return name; } Value do_evaluate(const std::shared_ptr & context) const override { if (!context->contains(name)) { @@ -693,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall }; static std::string typeToString(Type t) { switch (t) { @@ -714,6 +727,10 @@ class TemplateToken { case Type::EndFilter: return "endfilter"; case Type::Generation: return "generation"; case Type::EndGeneration: return "endgeneration"; + case Type::Break: return "break"; + case Type::Continue: return "continue"; + case Type::Call: return "call"; + case Type::EndCall: return "endcall"; } return "Unknown"; } @@ -729,51 +746,51 @@ class TemplateToken { struct TextTemplateToken : public TemplateToken { std::string text; - TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} + TextTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const 
std::string& t) : TemplateToken(Type::Text, loc, pre, post), text(t) {} }; struct ExpressionTemplateToken : public TemplateToken { std::shared_ptr expr; - ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} + ExpressionTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, loc, pre, post), expr(std::move(e)) {} }; struct IfTemplateToken : public TemplateToken { std::shared_ptr condition; - IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} + IfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, loc, pre, post), condition(std::move(c)) {} }; struct ElifTemplateToken : public TemplateToken { std::shared_ptr condition; - ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} + ElifTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, loc, pre, post), condition(std::move(c)) {} }; struct ElseTemplateToken : public TemplateToken { - ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} + ElseTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, loc, pre, post) {} }; struct EndIfTemplateToken : public TemplateToken { - EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} + EndIfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : 
TemplateToken(Type::EndIf, loc, pre, post) {} }; struct MacroTemplateToken : public TemplateToken { std::shared_ptr name; Expression::Parameters params; - MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) - : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} + MacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, loc, pre, post), name(std::move(n)), params(std::move(p)) {} }; struct EndMacroTemplateToken : public TemplateToken { - EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} + EndMacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, loc, pre, post) {} }; struct FilterTemplateToken : public TemplateToken { std::shared_ptr filter; - FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) - : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} + FilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, loc, pre, post), filter(std::move(filter)) {} }; struct EndFilterTemplateToken : public TemplateToken { - EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} + EndFilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, loc, pre, post) {} }; struct ForTemplateToken : public TemplateToken { @@ -781,38 +798,65 @@ struct ForTemplateToken : public TemplateToken { std::shared_ptr iterable; std::shared_ptr condition; bool recursive; - ForTemplateToken(const Location & location, 
SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + ForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, std::shared_ptr && c, bool r) - : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} + : TemplateToken(Type::For, loc, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} }; struct EndForTemplateToken : public TemplateToken { - EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} + EndForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, loc, pre, post) {} }; struct GenerationTemplateToken : public TemplateToken { - GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} + GenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, loc, pre, post) {} }; struct EndGenerationTemplateToken : public TemplateToken { - EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} + EndGenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, loc, pre, post) {} }; struct SetTemplateToken : public TemplateToken { std::string ns; std::vector var_names; std::shared_ptr value; - SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} + SetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling 
post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, loc, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} }; struct EndSetTemplateToken : public TemplateToken { - EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} + EndSetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, loc, pre, post) {} }; struct CommentTemplateToken : public TemplateToken { std::string text; - CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} + CommentTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, loc, pre, post), text(t) {} +}; + +enum class LoopControlType { Break, Continue }; + +class LoopControlException : public std::runtime_error { +public: + LoopControlType control_type; + LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} + LoopControlException(LoopControlType control_type) + : std::runtime_error((control_type == LoopControlType::Continue ? 
"continue" : "break") + std::string(" outside of a loop")), + control_type(control_type) {} +}; + +struct LoopControlTemplateToken : public TemplateToken { + LoopControlType control_type; + LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} +}; + +struct CallTemplateToken : public TemplateToken { + std::shared_ptr expr; + CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) + : TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {} +}; + +struct EndCallTemplateToken : public TemplateToken { + EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) + : TemplateToken(Type::EndCall, loc, pre, post) {} }; class TemplateNode { @@ -825,6 +869,12 @@ class TemplateNode { void render(std::ostringstream & out, const std::shared_ptr & context) const { try { do_render(out, context); + } catch (const LoopControlException & e) { + // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop. 
+ std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw LoopControlException(err.str(), e.control_type); } catch (const std::exception & e) { std::ostringstream err; err << e.what(); @@ -844,8 +894,8 @@ class TemplateNode { class SequenceNode : public TemplateNode { std::vector> children; public: - SequenceNode(const Location & location, std::vector> && c) - : TemplateNode(location), children(std::move(c)) {} + SequenceNode(const Location & loc, std::vector> && c) + : TemplateNode(loc), children(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const auto& child : children) child->render(out, context); } @@ -854,7 +904,7 @@ class SequenceNode : public TemplateNode { class TextNode : public TemplateNode { std::string text; public: - TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + TextNode(const Location & loc, const std::string& t) : TemplateNode(loc), text(t) {} void do_render(std::ostringstream & out, const std::shared_ptr &) const override { out << text; } @@ -863,7 +913,7 @@ class TextNode : public TemplateNode { class ExpressionNode : public TemplateNode { std::shared_ptr expr; public: - ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + ExpressionNode(const Location & loc, std::shared_ptr && e) : TemplateNode(loc), expr(std::move(e)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); auto result = expr->evaluate(context); @@ -880,8 +930,8 @@ class ExpressionNode : public TemplateNode { class IfNode : public TemplateNode { std::vector, std::shared_ptr>> cascade; public: - IfNode(const Location & location, std::vector, std::shared_ptr>> && c) - : TemplateNode(location), cascade(std::move(c)) {} + 
IfNode(const Location & loc, std::vector, std::shared_ptr>> && c) + : TemplateNode(loc), cascade(std::move(c)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { for (const auto& branch : cascade) { auto enter_branch = true; @@ -897,6 +947,15 @@ class IfNode : public TemplateNode { } }; +class LoopControlNode : public TemplateNode { + LoopControlType control_type_; + public: + LoopControlNode(const Location & loc, LoopControlType control_type) : TemplateNode(loc), control_type_(control_type) {} + void do_render(std::ostringstream &, const std::shared_ptr &) const override { + throw LoopControlException(control_type_); + } +}; + class ForNode : public TemplateNode { std::vector var_names; std::shared_ptr iterable; @@ -905,9 +964,9 @@ class ForNode : public TemplateNode { bool recursive; std::shared_ptr else_body; public: - ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + ForNode(const Location & loc, std::vector && var_names, std::shared_ptr && iterable, std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) - : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + : TemplateNode(loc), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { // https://jinja.palletsprojects.com/en/3.0.x/templates/#for @@ -961,7 +1020,12 @@ class ForNode : public TemplateNode { loop.set("last", i == (n - 1)); loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); loop.set("nextitem", i < n - 1 ? 
filtered_items.at(i + 1) : Value()); - body->render(out, loop_context); + try { + body->render(out, loop_context); + } catch (const LoopControlException & e) { + if (e.control_type == LoopControlType::Break) break; + if (e.control_type == LoopControlType::Continue) continue; + } } } }; @@ -987,8 +1051,8 @@ class MacroNode : public TemplateNode { std::shared_ptr body; std::unordered_map named_param_positions; public: - MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) - : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + MacroNode(const Location & loc, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(loc), name(std::move(n)), params(std::move(p)), body(std::move(b)) { for (size_t i = 0; i < params.size(); ++i) { const auto & name = params[i].first; if (!name.empty()) { @@ -999,31 +1063,36 @@ class MacroNode : public TemplateNode { void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { if (!name) throw std::runtime_error("MacroNode.name is null"); if (!body) throw std::runtime_error("MacroNode.body is null"); - auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { - auto call_context = macro_context; + auto callable = Value::callable([this, macro_context](const std::shared_ptr & call_context, ArgumentsValue & args) { + auto execution_context = Context::make(Value::object(), macro_context); + + if (call_context->contains("caller")) { + execution_context->set("caller", call_context->get("caller")); + } + std::vector param_set(params.size(), false); for (size_t i = 0, n = args.args.size(); i < n; i++) { auto & arg = args.args[i]; if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); param_set[i] = true; auto & param_name = params[i].first; - call_context->set(param_name, arg); + 
execution_context->set(param_name, arg); } for (auto & [arg_name, value] : args.kwargs) { auto it = named_param_positions.find(arg_name); if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - call_context->set(arg_name, value); + execution_context->set(arg_name, value); param_set[it->second] = true; } // Set default values for parameters that were not passed for (size_t i = 0, n = params.size(); i < n; i++) { if (!param_set[i] && params[i].second != nullptr) { - auto val = params[i].second->evaluate(context); - call_context->set(params[i].first, val); + auto val = params[i].second->evaluate(call_context); + execution_context->set(params[i].first, val); } } - return body->render(call_context); + return body->render(execution_context); }); macro_context->set(name->get_name(), callable); } @@ -1034,8 +1103,8 @@ class FilterNode : public TemplateNode { std::shared_ptr body; public: - FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) - : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + FilterNode(const Location & loc, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(loc), filter(std::move(f)), body(std::move(b)) {} void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { if (!filter) throw std::runtime_error("FilterNode.filter is null"); @@ -1057,8 +1126,8 @@ class SetNode : public TemplateNode { std::vector var_names; std::shared_ptr value; public: - SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) - : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + SetNode(const Location & loc, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(loc), ns(ns), var_names(vns), value(std::move(v)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if 
(!value) throw std::runtime_error("SetNode.value is null"); if (!ns.empty()) { @@ -1080,8 +1149,8 @@ class SetTemplateNode : public TemplateNode { std::string name; std::shared_ptr template_value; public: - SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) - : TemplateNode(location), name(name), template_value(std::move(tv)) {} + SetTemplateNode(const Location & loc, const std::string & name, std::shared_ptr && tv) + : TemplateNode(loc), name(name), template_value(std::move(tv)) {} void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); Value value { template_value->render(context) }; @@ -1094,8 +1163,8 @@ class IfExpr : public Expression { std::shared_ptr then_expr; std::shared_ptr else_expr; public: - IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) - : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + IfExpr(const Location & loc, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(loc), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!condition) throw std::runtime_error("IfExpr.condition is null"); if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); @@ -1112,16 +1181,16 @@ class IfExpr : public Expression { class LiteralExpr : public Expression { Value value; public: - LiteralExpr(const Location & location, const Value& v) - : Expression(location), value(v) {} + LiteralExpr(const Location & loc, const Value& v) + : Expression(loc), value(v) {} Value do_evaluate(const std::shared_ptr &) const override { return value; } }; class ArrayExpr : public Expression { std::vector> elements; public: - ArrayExpr(const Location & location, std::vector> && e) - : 
Expression(location), elements(std::move(e)) {} + ArrayExpr(const Location & loc, std::vector> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { auto result = Value::array(); for (const auto& e : elements) { @@ -1135,8 +1204,8 @@ class ArrayExpr : public Expression { class DictExpr : public Expression { std::vector, std::shared_ptr>> elements; public: - DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) - : Expression(location), elements(std::move(e)) {} + DictExpr(const Location & loc, std::vector, std::shared_ptr>> && e) + : Expression(loc), elements(std::move(e)) {} Value do_evaluate(const std::shared_ptr & context) const override { auto result = Value::object(); for (const auto& [key, value] : elements) { @@ -1150,9 +1219,9 @@ class DictExpr : public Expression { class SliceExpr : public Expression { public: - std::shared_ptr start, end; - SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) - : Expression(location), start(std::move(s)), end(std::move(e)) {} + std::shared_ptr start, end, step; + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e, std::shared_ptr && st = nullptr) + : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1162,25 +1231,42 @@ class SubscriptExpr : public Expression { std::shared_ptr base; std::shared_ptr index; public: - SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) - : Expression(location), base(std::move(b)), index(std::move(i)) {} + SubscriptExpr(const Location & loc, std::shared_ptr && b, std::shared_ptr && i) + : Expression(loc), base(std::move(b)), index(std::move(i)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!base) throw std::runtime_error("SubscriptExpr.base is 
null"); if (!index) throw std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + auto len = target_value.size(); + auto wrap = [len](int64_t i) -> int64_t { + if (i < 0) { + return i + len; + } + return i; + }; + int64_t step = slice->step ? slice->step->evaluate(context).get() : 1; + if (!step) { + throw std::runtime_error("slice step cannot be zero"); + } + int64_t start = slice->start ? wrap(slice->start->evaluate(context).get()) : (step < 0 ? len - 1 : 0); + int64_t end = slice->end ? wrap(slice->end->evaluate(context).get()) : (step < 0 ? -1 : len); if (target_value.is_string()) { std::string s = target_value.get(); - if (start < 0) start = s.size() + start; - if (end < 0) end = s.size() + end; - return s.substr(start, end - start); - } else if (target_value.is_array()) { - if (start < 0) start = target_value.size() + start; - if (end < 0) end = target_value.size() + end; + + std::string result; + if (start < end && step == 1) { + result = s.substr(start, end - start); + } else { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { + result += s[i]; + } + } + return result; + + } else if (target_value.is_array()) { auto result = Value::array(); - for (auto i = start; i < end; ++i) { + for (int64_t i = start; step > 0 ? 
i < end : i > end; i += step) { result.push_back(target_value.at(i)); } return result; @@ -1205,8 +1291,8 @@ class UnaryOpExpr : public Expression { enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; std::shared_ptr expr; Op op; - UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) - : Expression(location), expr(std::move(e)), op(o) {} + UnaryOpExpr(const Location & loc, std::shared_ptr && e, Op o) + : Expression(loc), expr(std::move(e)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); auto e = expr->evaluate(context); @@ -1223,6 +1309,12 @@ class UnaryOpExpr : public Expression { } }; +static bool in(const Value & value, const Value & container) { + return (((container.is_array() || container.is_object()) && container.contains(value)) || + (value.is_string() && container.is_string() && + container.to_str().find(value.to_str()) != std::string::npos)); +}; + class BinaryOpExpr : public Expression { public: enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; @@ -1231,8 +1323,8 @@ class BinaryOpExpr : public Expression { std::shared_ptr right; Op op; public: - BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) - : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + BinaryOpExpr(const Location & loc, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(loc), left(std::move(l)), right(std::move(r)), op(o) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); @@ -1255,6 +1347,8 @@ class BinaryOpExpr : public Expression { if (name == "iterable") return l.is_iterable(); if (name == "sequence") return l.is_array(); if (name == "defined") return !l.is_null(); 
+ if (name == "true") return l.to_bool(); + if (name == "false") return !l.to_bool(); throw std::runtime_error("Unknown type for 'is' operator: " + name); }; auto value = eval(); @@ -1285,8 +1379,8 @@ class BinaryOpExpr : public Expression { case Op::Gt: return l > r; case Op::Le: return l <= r; case Op::Ge: return l >= r; - case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); - case Op::NotIn: return !(r.is_array() && r.contains(l)); + case Op::In: return in(l, r); + case Op::NotIn: return !in(l, r); default: break; } throw std::runtime_error("Unknown binary operator"); @@ -1340,13 +1434,34 @@ struct ArgumentsExpression { } }; -static std::string strip(const std::string & s) { - auto start = s.find_first_not_of(" \t\n\r"); +static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { + auto charset = chars.empty() ? " \t\n\r" : chars; + auto start = left ? s.find_first_not_of(charset) : 0; if (start == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\n\r"); + auto end = right ? 
s.find_last_not_of(charset) : s.size() - 1; return s.substr(start, end - start + 1); } +static std::vector split(const std::string & s, const std::string & sep) { + std::vector result; + size_t start = 0; + size_t end = s.find(sep); + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + sep.length(); + end = s.find(sep, start); + } + result.push_back(s.substr(start)); + return result; +} + +static std::string capitalize(const std::string & s) { + if (s.empty()) return s; + auto result = s; + result[0] = std::toupper(result[0]); + return result; +} + static std::string html_escape(const std::string & s) { std::string result; result.reserve(s.size()); @@ -1368,8 +1483,8 @@ class MethodCallExpr : public Expression { std::shared_ptr method; ArgumentsExpression args; public: - MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + MethodCallExpr(const Location & loc, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null"); @@ -1404,6 +1519,13 @@ class MethodCallExpr : public Expression { } else if (method->get_name() == "pop") { vargs.expectArgs("pop method", {1, 1}, {0, 0}); return obj.pop(vargs.args[0]); + } else if (method->get_name() == "keys") { + vargs.expectArgs("keys method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value(key)); + } + return result; } else if (method->get_name() == "get") { vargs.expectArgs("get method", {1, 2}, {0, 0}); auto key = vargs.args[0]; @@ -1422,12 +1544,47 @@ 
class MethodCallExpr : public Expression { } else if (obj.is_string()) { auto str = obj.get(); if (method->get_name() == "strip") { - vargs.expectArgs("strip method", {0, 0}, {0, 0}); - return Value(strip(str)); + vargs.expectArgs("strip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars)); + } else if (method->get_name() == "lstrip") { + vargs.expectArgs("lstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ true, /* right= */ false)); + } else if (method->get_name() == "rstrip") { + vargs.expectArgs("rstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ false, /* right= */ true)); + } else if (method->get_name() == "split") { + vargs.expectArgs("split method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = split(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; + } else if (method->get_name() == "capitalize") { + vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); + return Value(capitalize(str)); + } else if (method->get_name() == "upper") { + vargs.expectArgs("upper method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::toupper); + return Value(result); + } else if (method->get_name() == "lower") { + vargs.expectArgs("lower method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return Value(result); } else if (method->get_name() == "endswith") { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "startswith") { + 
vargs.expectArgs("startswith method", {1, 1}, {0, 0}); + auto prefix = vargs.args[0].get(); + return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin()); } else if (method->get_name() == "title") { vargs.expectArgs("title method", {0, 0}, {0, 0}); auto res = str; @@ -1436,6 +1593,19 @@ class MethodCallExpr : public Expression { else res[i] = std::tolower(res[i]); } return res; + } else if (method->get_name() == "replace") { + vargs.expectArgs("replace method", {2, 3}, {0, 0}); + auto before = vargs.args[0].get(); + auto after = vargs.args[1].get(); + auto count = vargs.args.size() == 3 ? vargs.args[2].get() + : str.length(); + size_t start_pos = 0; + while ((start_pos = str.find(before, start_pos)) != std::string::npos && + count-- > 0) { + str.replace(start_pos, before.length(), after); + start_pos += after.length(); + } + return str; } } throw std::runtime_error("Unknown method: " + method->get_name()); @@ -1446,8 +1616,8 @@ class CallExpr : public Expression { public: std::shared_ptr object; ArgumentsExpression args; - CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) - : Expression(location), object(std::move(obj)), args(std::move(a)) {} + CallExpr(const Location & loc, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(loc), object(std::move(obj)), args(std::move(a)) {} Value do_evaluate(const std::shared_ptr & context) const override { if (!object) throw std::runtime_error("CallExpr.object is null"); auto obj = object->evaluate(context); @@ -1459,11 +1629,45 @@ class CallExpr : public Expression { } }; +class CallNode : public TemplateNode { + std::shared_ptr expr; + std::shared_ptr body; + +public: + CallNode(const Location & loc, std::shared_ptr && e, std::shared_ptr && b) + : TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw 
std::runtime_error("CallNode.expr is null"); + if (!body) throw std::runtime_error("CallNode.body is null"); + + auto caller = Value::callable([this, context](const std::shared_ptr &, ArgumentsValue &) -> Value { + return Value(body->render(context)); + }); + + context->set("caller", caller); + + auto call_expr = dynamic_cast(expr.get()); + if (!call_expr) { + throw std::runtime_error("Invalid call block syntax - expected function call"); + } + + Value function = call_expr->object->evaluate(context); + if (!function.is_callable()) { + throw std::runtime_error("Call target must be callable: " + function.dump()); + } + ArgumentsValue args = call_expr->args.evaluate(context); + + Value result = function.call(context, args); + out << result.to_str(); + } +}; + class FilterExpr : public Expression { std::vector> parts; public: - FilterExpr(const Location & location, std::vector> && p) - : Expression(location), parts(std::move(p)) {} + FilterExpr(const Location & loc, std::vector> && p) + : Expression(loc), parts(std::move(p)) {} Value do_evaluate(const std::shared_ptr & context) const override { Value result; bool first = true; @@ -1754,7 +1958,7 @@ class Parser { auto left = parseStringConcat(); if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); - static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); static std::regex not_tok(R"(not\b)"); std::string op_str; while (!(op_str = consumeToken(compare_tok)).empty()) { @@ -1990,28 +2194,37 @@ class Parser { while (it != end && consumeSpaces() && peekSymbols({ "[", "." 
})) { if (!consumeToken("[").empty()) { - std::shared_ptr index; + std::shared_ptr index; + auto slice_loc = get_location(); + std::shared_ptr start, end, step; + bool has_first_colon = false, has_second_colon = false; + + if (!peekSymbols({ ":" })) { + start = parseExpression(); + } + + if (!consumeToken(":").empty()) { + has_first_colon = true; + if (!peekSymbols({ ":", "]" })) { + end = parseExpression(); + } if (!consumeToken(":").empty()) { - auto slice_end = parseExpression(); - index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); - } else { - auto slice_start = parseExpression(); - if (!consumeToken(":").empty()) { - consumeSpaces(); - if (peekSymbols({ "]" })) { - index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); - } else { - auto slice_end = parseExpression(); - index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); - } - } else { - index = std::move(slice_start); + has_second_colon = true; + if (!peekSymbols({ "]" })) { + step = parseExpression(); } } - if (!index) throw std::runtime_error("Empty index in subscript"); - if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + } + + if ((has_first_colon || has_second_colon)) { + index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); + } else { + index = std::move(start); + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); - value = std::make_shared(value->location, std::move(value), std::move(index)); + value = std::make_shared(value->location, std::move(value), std::move(index)); } else if (!consumeToken(".").empty()) { auto identifier = parseIdentifier(); if (!identifier) throw std::runtime_error("Expected identifier in subscript"); @@ -2133,7 +2346,7 @@ class Parser { using TemplateTokenIterator = 
TemplateTokenVector::const_iterator; std::vector parseVarNames() { - static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); std::vector group; if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); @@ -2156,13 +2369,13 @@ class Parser { } TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); - static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); - static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); - static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); + static std::regex block_close_regex(R"(\s*([-~])?%\})"); TemplateTokenVector tokens; std::vector group; @@ -2246,7 +2459,7 @@ class Parser { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "set") { - static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); std::string ns; std::vector var_names; @@ -2282,6 +2495,15 @@ class Parser { } else if (keyword == "endmacro") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } 
else if (keyword == "call") { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (keyword == "endcall") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "filter") { auto filter = parseExpression(); if (!filter) throw std::runtime_error("Expected expression in filter block"); @@ -2291,10 +2513,18 @@ class Parser { } else if (keyword == "endfilter") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "break" || keyword == "continue") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue)); } else { throw std::runtime_error("Unexpected block: " + keyword); } } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + throw std::runtime_error("Internal error: Expected a comment"); + throw std::runtime_error("Missing end of comment tag"); + } auto text_end = it + match.position(); text = std::string(it, text_end); it = text_end; @@ -2359,7 +2589,7 @@ class Parser { auto text = text_token->text; if (post_space == SpaceHandling::Strip) { - static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + static std::regex trailing_space_regex(R"(\s+$)"); text = std::regex_replace(text, trailing_space_regex, ""); } else if (options.lstrip_blocks && it != end) { auto i = text.size(); @@ -2369,10 +2599,10 @@ class Parser { } } if (pre_space == SpaceHandling::Strip) { - static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } 
else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { - if (text.length() > 0 && text[0] == '\n') { + if (!text.empty() && text[0] == '\n') { text.erase(0, 1); } } @@ -2406,6 +2636,12 @@ class Parser { throw unterminated(**start); } children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto call_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(call_token->expr), std::move(body))); } else if (auto filter_token = dynamic_cast(token.get())) { auto body = parseTemplate(begin, it, end); if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { @@ -2414,9 +2650,12 @@ class Parser { children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); } else if (dynamic_cast(token.get())) { // Ignore comments + } else if (auto ctrl_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) + || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) @@ -2448,7 +2687,7 @@ class Parser { TemplateTokenIterator begin = tokens.begin(); auto it = begin; TemplateTokenIterator end = tokens.end(); - return parser.parseTemplate(begin, it, end, /* full= */ true); + return parser.parseTemplate(begin, it, end, /* fully= */ true); } }; @@ -2487,21 +2726,17 @@ inline std::shared_ptr Context::builtins() { throw std::runtime_error(args.at("message").get()); })); globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { - return 
Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + return Value(args.at("value").dump(args.get("indent", -1), /* to_json= */ true)); })); globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { auto items = Value::array(); if (args.contains("object")) { auto & obj = args.at("object"); - if (obj.is_string()) { - auto json_obj = json::parse(obj.get()); - for (const auto & kv : json_obj.items()) { - items.push_back(Value::array({kv.key(), kv.value()})); - } - } else if (!obj.is_null()) { - for (auto & key : obj.keys()) { - items.push_back(Value::array({key, obj.at(key)})); - } + if (!obj.is_object()) { + throw std::runtime_error("Can only get item pairs from a mapping"); + } + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); } } return items; @@ -2509,21 +2744,25 @@ inline std::shared_ptr Context::builtins() { globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { auto items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not a list"); - if (items.size() == 0) return Value(); + if (items.empty()) return Value(); return items.at(items.size() - 1); })); globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { auto & text = args.at("text"); return text.is_null() ? 
text : Value(strip(text.get())); })); - globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { - auto text = args.at("text"); - if (text.is_null()) return text; - std::string res; - auto str = text.get(); - std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); - return Value(res); - })); + auto char_transform_function = [](const std::string & name, const std::function & fn) { + return simple_function(name, { "text" }, [=](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), fn); + return Value(res); + }); + }; + globals.set("lower", char_transform_function("lower", ::tolower)); + globals.set("upper", char_transform_function("upper", ::toupper)); globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { args.expectArgs("default", {2, 3}, {0, 1}); auto & value = args.args[0]; @@ -2572,6 +2811,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); std::ostringstream oss; auto first = true; for (size_t i = 0, n = items.size(); i < n; ++i) { @@ -2624,6 +2864,9 @@ inline std::shared_ptr Context::builtins() { if (!items.is_array()) throw std::runtime_error("object is not iterable"); return items; })); + globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr &, Value & args) -> Value { + return in(args.at("item"), args.at("items")); + })); globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { auto & items = args.at("items"); if (!items.is_array()) throw 
std::runtime_error("object is not iterable"); @@ -2652,8 +2895,17 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; + if (items.is_null()) { + return Value::array(); + } + if (!items.is_array()) { + throw std::runtime_error("object is not iterable: " + items.dump()); + } + auto filter_fn = context->get(args.args[1]); - if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + if (filter_fn.is_null()) { + throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + } auto filter_args = Value::array(); for (size_t i = 2, n = args.args.size(); i < n; i++) { @@ -2729,6 +2981,7 @@ inline std::shared_ptr Context::builtins() { auto & items = args.args[0]; if (items.is_null()) return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); auto attr_name = args.args[1].get(); bool has_test = false; @@ -2774,20 +3027,25 @@ inline std::shared_ptr Context::builtins() { auto v = arg.get(); startEndStep[i] = v; param_set[i] = true; - } } - for (auto & [name, value] : args.kwargs) { - size_t i; - if (name == "start") i = 0; - else if (name == "end") i = 1; - else if (name == "step") i = 2; - else throw std::runtime_error("Unknown argument " + name + " for function range"); - - if (param_set[i]) { - throw std::runtime_error("Duplicate argument " + name + " for function range"); - } - startEndStep[i] = value.get(); - param_set[i] = true; + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") { + i = 0; + } else if (name == "end") { + i = 1; + } else if (name == "step") { + i = 2; + } else { + throw std::runtime_error("Unknown argument " + name + " for function range"); + } + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + 
name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; } if (!param_set[1]) { throw std::runtime_error("Missing required argument 'end' for function range"); diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 619b539..acaf969 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -43,9 +43,6 @@ def raise_exception(message: str): raise ValueError(message) -def tojson(eval_ctx, value, indent=None): - return json.dumps(value, indent=indent) - TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26') @@ -66,7 +63,7 @@ def add_system(messages, system_prompt): existing_system = messages[0]["content"] messages[0] = { "role": "system", - "content": existing_system + "\n" + system_prompt, + "content": existing_system + "\n\n" + system_prompt, } else: messages.insert(0, { @@ -86,7 +83,7 @@ class TemplateCaps: requires_object_arguments: bool = False requires_non_null_content: bool = False requires_typed_content: bool = False - + def to_json(self): return json.dumps({ "supports_tools": self.supports_tools, @@ -99,121 +96,269 @@ def to_json(self): # "requires_non_null_content": self.requires_non_null_content, "requires_typed_content": self.requires_typed_content, }, indent=2) - -def detect_caps(template_file, template): - basic_extra_context = { - "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>", - } - def try_raw_render(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + + +class chat_template: + + def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + basic_extra_context = { + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>", + } + try: - out = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) + out = self.template.render(messages=messages, 
tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) # print(out, file=sys.stderr) return out except BaseException as e: - # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) + # print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - - caps = TemplateCaps() - - user_needle = "" - sys_needle = "" - dummy_str_user_msg = {"role": "user", "content": user_needle } - dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} - - caps.requires_typed_content = \ - (user_needle not in try_raw_render([dummy_str_user_msg])) \ - and (user_needle in try_raw_render([dummy_typed_user_msg])) - dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg - - needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} - - caps.supports_system_role = sys_needle in try_raw_render([needle_system_msg, dummy_user_msg]) - - out = try_raw_render([dummy_user_msg], tools=[{ - "name": "some_tool", - "type": "function", - "function": { + + def __init__(self, template, env=None, filters=None, global_functions=None): + if not env: + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[jinja2.ext.loopcontrols] + ) + if filters: + for name, func in filters.items(): + env.filters[name] = func + if global_functions: + for name, func in global_functions.items(): + env.globals[name] = func + self.env = env + self.template = env.from_string(template) + + caps = TemplateCaps() + + user_needle = "" + sys_needle = "" + dummy_str_user_msg = {"role": "user", "content": user_needle } + dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} + + caps.requires_typed_content = \ + (user_needle not in self.try_raw_render([dummy_str_user_msg])) \ + and 
(user_needle in self.try_raw_render([dummy_typed_user_msg])) + dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg + + needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} + + caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg]) + + out = self.try_raw_render([dummy_user_msg], tools=[{ "name": "some_tool", - "description": "Some tool", - "parameters": { - "type": "object", - "properties": { - "arg": { - "type": "string", - "description": "Some arg", + "type": "function", + "function": { + "name": "some_tool", + "description": "Some tool", + "parameters": { + "type": "object", + "properties": { + "arg": { + "type": "string", + "description": "Some arg", + }, }, + "required": ["arg"], }, - "required": ["arg"], }, - }, - }]) - caps.supports_tools = "some_tool" in out - - def make_tool_calls_msg(tool_calls, content=None): - return { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - def make_tool_call(tool_name, arguments): - return { - "id": "call_1___", - "type": "function", - "function": { - "arguments": arguments, - "name": tool_name, + }]) + caps.supports_tools = "some_tool" in out + + caps.requires_non_null_content = \ + (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ + and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) + + def make_tool_calls_msg(tool_calls, content=None): + return { + "role": "assistant", + "content": "" if content is None and caps.requires_non_null_content else content, + "tool_calls": tool_calls, + } + def make_tool_call(tool_name, arguments): + return { + "id": "call_1___", + "type": "function", + "function": { + "arguments": arguments, + "name": tool_name, + } } - } - - dummy_args_obj = {"argument_needle": "print('Hello, World!')"} - out = 
try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), - ]) - tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out - out = try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), - ]) - tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out - - caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments - caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - - empty_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}]) - none_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}]) - caps.requires_non_null_content = \ - (user_needle in try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ - and (user_needle not in try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) - - if caps.supports_tool_calls: - dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) - tc1 = make_tool_call("test_tool1", dummy_args) - tc2 = make_tool_call("test_tool2", dummy_args) - out = try_raw_render([ + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} + + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1, tc2]), + make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), ]) - caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - - out = try_raw_render([ + tool_call_renders_str_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1]), - { - "role": "tool", - "name": "test_tool1", - "content": "Some response!", - "tool_call_id": "call_911_", - } + make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) - 
caps.supports_tool_responses = "Some response!" in out - caps.supports_tool_call_id = "call_911_" in out - - return caps - -async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): + tool_call_renders_obj_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out + + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments + caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments + + if caps.supports_tool_calls: + dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) + tc1 = make_tool_call("test_tool1", dummy_args) + tc2 = make_tool_call("test_tool2", dummy_args) + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1, tc2]), + ]) + caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out + + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1]), + { + "role": "tool", + "name": "test_tool1", + "content": "Some response!", + "tool_call_id": "call_911_", + } + ]) + caps.supports_tool_responses = "Some response!" 
in out + caps.supports_tool_call_id = "call_911_" in out + + self.tool_call_example = None + try: + if not caps.supports_tools: + user_msg = {"role": "user", "content": "Hey"} + args = {"arg1": "some_value"} + tool_call_msg = { + "role": "assistant", + "content": "" if caps.requires_non_null_content else None, + "tool_calls": [ + { + "id": "call_1___", + "type": "function", + "function": { + "name": "tool_name", + "arguments": args if caps.requires_object_arguments else json.dumps(args), + }, + }, + ], + } + prefix = self.try_raw_render([user_msg], add_generation_prompt=True) + full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False) + + common_prefix_length = 0 + for i in range(min(len(prefix), len(full))): + if prefix[i] != full[i]: + break + if prefix[i] == '<': + # DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + # but it removes thinking tags for past messages. + # The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. 
+ continue + common_prefix_length = i + 1 + + example = full[common_prefix_length:] + if "tool_name" not in example and "some_value" not in example: + print("Failed to infer a tool call example (possible template bug)", file=sys.stderr) + else: + self.tool_call_example = example + + except Exception as e: + print(f"Failed to generate tool call example: {e}", file=sys.stderr) + + self.original_caps = caps + + def needs_polyfills(self, context): + has_tools = context.get('tools') is not None + caps = self.original_caps + return not caps.supports_system_role \ + or (has_tools is not None and (False \ + or not caps.supports_tools \ + or not caps.supports_tool_responses \ + or not caps.supports_tool_calls \ + or caps.requires_object_arguments \ + )) \ + or caps.requires_typed_content + + def apply(self, context: dict): + assert isinstance(context, dict) + context = json.loads(json.dumps(context)) + + caps = self.original_caps + has_tools = 'tools' in context + + if self.needs_polyfills(context): + if has_tools and not caps.supports_tools: + add_system(context['messages'], + f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" + + ("\n\nExample tool call syntax:\n\n" + self.tool_call_example + "\n\n" if self.tool_call_example is not None else "")) + + for message in context['messages']: + if 'tool_calls' in message: + for tool_call in message['tool_calls']: + if caps.requires_object_arguments: + if tool_call.get('type') == 'function': + arguments = tool_call['function']['arguments'] + try: + arguments = json.loads(arguments) + except: + pass + tool_call['function']['arguments'] = arguments + if not caps.supports_tool_calls: + message['content'] = json.dumps({ + "tool_calls": [ + { + "name": tc['function']['name'], + "arguments": json.loads(tc['function']['arguments']), + "id": tc.get('id'), + } + for tc in message['tool_calls'] + ], + "content": None if message.get('content', '') == '' else 
message['content'], + }, indent=2) + del message['tool_calls'] + if message.get('role') == 'tool' and not caps.supports_tool_responses: + message['role'] = 'user' + message['content'] = json.dumps({ + "tool_response": { + "tool": message['name'], + "content": message['content'], + "tool_call_id": message.get('tool_call_id'), + } + }, indent=2) + del message['name'] + + if caps.requires_typed_content: + for message in context['messages']: + if 'content' in message and isinstance(message['content'], str): + message['content'] = [{"type": "text", "text": message['content']}] + + try: + out = self.template.render(**context) + out = out.replace("\\u0027", "'") + out = out.replace('"', '"') + out = out.replace(''', "'") + return out + except Exception as e1: + for message in context['messages']: + if message.get("content") is None: + message["content"] = "" + + try: + return self.template.render(**context) + except Exception as e2: + logger.info(f" ERROR: {e2} (after first error: {e1})") + return f"ERROR: {e2}" + +@dataclass +class Context: + name: str + file: str + bindings: dict + + +async def handle_chat_template(output_folder, model_id, variant, template_src, contexts: list[Context]): if '{% generation %}' in template_src: print('Removing {% generation %} blocks from template', file=sys.stderr) template_src = template_src.replace('{% generation %}', '').replace('{% endgeneration %}', '') @@ -221,109 +366,48 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c model_name = model_id.replace("/", "-") base_name = f'{model_name}-{variant}' if variant else model_name template_file = join_cmake_path(output_folder, f'{base_name}.jinja') + caps_file = join_cmake_path(output_folder, f'{base_name}.caps.json') async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[jinja2.ext.loopcontrols] - ) - template = env.from_string(template_src) - - 
env.filters['safe'] = lambda x: x - env.filters['tojson'] = tojson - env.globals['raise_exception'] = raise_exception - env.globals['strftime_now'] = strftime_now - - caps = detect_caps(template_file, template) - - if not context_files: + template = chat_template(template_src, + filters={ + 'safe': lambda x: x, + }, + global_functions={ + 'raise_exception': raise_exception, + 'strftime_now': strftime_now, + }) + caps = template.original_caps + + if not contexts: print(f"{template_file} {caps_file} n/a {template_file}") return async with aiofiles.open(caps_file, 'w') as f: await f.write(caps.to_json()) - - for context_file in context_files: - context_name = os.path.basename(context_file).replace(".json", "") - async with aiofiles.open(context_file, 'r') as f: - context = json.loads(await f.read()) - has_tools = 'tools' in context - needs_tools_in_system = has_tools and not caps.supports_tools - - if not caps.supports_tool_calls and has_tools: - print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) - continue - - if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): + assert isinstance(contexts, list) + for context in contexts: + assert isinstance(context, Context) + assert isinstance(context.bindings, dict) + if not caps.supports_tool_calls and context.bindings.get('tools') is not None: + print(f'Skipping {context.name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue - output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt') - - if needs_tools_in_system: - add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") - - for message in context['messages']: - if 'tool_calls' in message: - for tool_call in message['tool_calls']: - if caps.requires_object_arguments: - if tool_call.get('type') == 'function': - arguments = tool_call['function']['arguments'] 
- try: - arguments = json.loads(arguments) - except: - pass - tool_call['function']['arguments'] = arguments - if not caps.supports_tool_calls: - message['content'] = json.dumps({ - "tool_calls": [ - { - "name": tc['function']['name'], - "arguments": json.loads(tc['function']['arguments']), - "id": tc.get('id'), - } - for tc in message['tool_calls'] - ], - "content": None if message.get('content', '') == '' else message['content'], - }, indent=2) - del message['tool_calls'] - if message.get('role') == 'tool' and not caps.supports_tool_responses: - message['role'] = 'user' - message['content'] = json.dumps({ - "tool_response": { - "tool": message['name'], - "content": message['content'], - "tool_call_id": message.get('tool_call_id'), - } - }, indent=2) - del message['name'] - - if caps.requires_typed_content: - for message in context['messages']: - if 'content' in message and isinstance(message['content'], str): - message['content'] = [{"type": "text", "text": message['content']}] - - try: - output = template.render(**context) - except Exception as e1: - for message in context['messages']: - if message.get("content") is None: - message["content"] = "" + needs_tools_in_system = len(context.bindings.get('tools', [])) > 0 and not caps.supports_tools + if not caps.supports_system_role and (any(m['role'] == 'system' for m in context.bindings['messages']) or needs_tools_in_system): + continue - try: - output = template.render(**context) - except Exception as e2: - logger.info(f" ERROR: {e2} (after first error: {e1})") - output = f"ERROR: {e2}" + output_file = join_cmake_path(output_folder, f'{base_name}-{context.name}.txt') + output = template.apply(context.bindings) async with aiofiles.open(output_file, 'w') as f: await f.write(output) - print(f"{template_file} {caps_file} {context_file} {output_file}") + print(f"{template_file} {caps_file} {context.file} {output_file}") async def async_hf_download(repo_id: str, filename: str) -> str: headers = build_hf_headers() @@ 
-333,26 +417,37 @@ async def async_hf_download(repo_id: str, filename: str) -> str: response.raise_for_status() return await response.text() -async def process_model(output_folder: str, model_id: str, context_files: list): +async def process_model(output_folder: str, model_id: str, contexts: list[Context]): try: + print(f"Processing model {model_id}...", file=sys.stderr) config_str = await async_hf_download(model_id, "tokenizer_config.json") - + try: config = json.loads(config_str) except json.JSONDecodeError: config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) - assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!' + if 'chat_template' not in config: + try: + chat_template = await async_hf_download(model_id, "chat_template.jinja") + config.update({'chat_template': chat_template}) + except Exception as e: + logger.error(f"Failed to fetch chat_template.jinja for model {model_id}: {e}") + raise e + + assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json or no chat_template.jinja file found!' 
chat_template = config['chat_template'] if isinstance(chat_template, str): - await handle_chat_template(output_folder, model_id, None, chat_template, context_files) + await handle_chat_template(output_folder, model_id, None, chat_template, contexts) else: await asyncio.gather(*[ - handle_chat_template(output_folder, model_id, ct['name'], ct['template'], context_files) + handle_chat_template(output_folder, model_id, ct['name'], ct['template'], contexts) for ct in chat_template ]) except Exception as e: logger.error(f"Error processing model {model_id}: {e}") + # import traceback + # traceback.print_exc() await handle_chat_template(output_folder, model_id, None, str(e), []) async def async_copy_file(src: str, dst: str): @@ -366,11 +461,15 @@ async def main(): parser.add_argument("json_context_files_or_model_ids", nargs="+", help="List of context JSON files or HuggingFace model IDs") args = parser.parse_args() - context_files = [] + contexts: list[Context] = [] model_ids = [] for file in args.json_context_files_or_model_ids: if file.endswith('.json'): - context_files.append(file) + async with aiofiles.open(file, 'r') as f: + contexts.append(Context( + name=os.path.basename(file).replace(".json", ""), + file=file, + bindings=json.loads(await f.read()))) else: model_ids.append(file) @@ -378,15 +477,10 @@ async def main(): if not os.path.isdir(output_folder): os.makedirs(output_folder) - # Copy context files to the output folder asynchronously - await asyncio.gather(*[ - async_copy_file(context_file, os.path.join(output_folder, os.path.basename(context_file))) - for context_file in context_files - ]) - - # Process models concurrently + # for model_id in model_ids: + # await process_model(output_folder, model_id, contexts) await asyncio.gather(*[ - process_model(output_folder, model_id, context_files) + process_model(output_folder, model_id, contexts) for model_id in model_ids ]) diff --git a/scripts/run_fuzzing_mode.sh b/scripts/fuzzing_tests.sh similarity index 100% 
rename from scripts/run_fuzzing_mode.sh rename to scripts/fuzzing_tests.sh diff --git a/scripts/run_tests.sh b/scripts/tests.sh similarity index 100% rename from scripts/run_tests.sh rename to scripts/tests.sh diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ccc942..db82c2d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,12 +13,45 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-syntax PRIVATE - nlohmann_json::nlohmann_json + minja + gtest_main + gmock +) + +if (WIN32) + message(STATUS "Skipping test-chat-template on Win32") +else() + add_executable(test-chat-template test-chat-template.cpp) + target_compile_features(test-chat-template PUBLIC cxx_std_17) + if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-chat-template PUBLIC _CRT_SECURE_NO_WARNINGS) + target_compile_options(gtest PRIVATE -Wno-language-extension-token) + endif() + target_link_libraries(test-chat-template PRIVATE + minja + gtest_main + gmock + ) +endif() + +add_executable(test-polyfills test-polyfills.cpp) +target_compile_features(test-polyfills PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-polyfills PUBLIC _CRT_SECURE_NO_WARNINGS) + target_compile_options(gtest PRIVATE -Wno-language-extension-token) +endif() +target_link_libraries(test-polyfills PRIVATE + minja gtest_main gmock ) if (NOT CMAKE_CROSSCOMPILING) gtest_discover_tests(test-syntax) + if (NOT WIN32) + gtest_discover_tests(test-chat-template) + endif() + add_test(NAME test-polyfills COMMAND test-polyfills) + set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endif() add_executable(test-capabilities test-capabilities.cpp) @@ -28,7 +61,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL 
"Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ar target_compile_options(gtest PRIVATE -Wno-language-extension-token) endif() target_link_libraries(test-capabilities PRIVATE - nlohmann_json::nlohmann_json + minja gtest_main gmock ) @@ -39,9 +72,15 @@ add_test(NAME test-syntax-jinja2 COMMAND test-syntax) set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}") -add_executable(test-chat-template test-chat-template.cpp) -target_compile_features(test-chat-template PUBLIC cxx_std_17) -target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json) +add_executable(test-supported-template test-supported-template.cpp) +target_compile_features(test-supported-template PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-supported-template PUBLIC _CRT_SECURE_NO_WARNINGS) +endif() +target_link_libraries(test-supported-template PRIVATE minja) + +# https://huggingface.co/models?other=conversational +# https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/?types=fine-tuned%2Cchat set(MODEL_IDS # List of model IDs to test the chat template of. @@ -51,62 +90,258 @@ set(MODEL_IDS # minja implementation on the same template and context, and compare the output with the golden. # # For Gated models, you'll need to run `huggingface-cli login` (and be granted access) to download their template. 
- + abacusai/Fewshot-Metamath-OrcaVicuna-Mistral + allenai/Llama-3.1-Tulu-3-405B + allenai/Llama-3.1-Tulu-3-405B-SFT + allenai/Llama-3.1-Tulu-3-8B + arcee-ai/Virtuoso-Lite + arcee-ai/Virtuoso-Medium-v2 + arcee-ai/Virtuoso-Small-v2 + AtlaAI/Selene-1-Mini-Llama-3.1-8B + avemio/GRAG-NEMO-12B-ORPO-HESSIAN-AI + BEE-spoke-data/tFINE-900m-instruct-orpo + bespokelabs/Bespoke-Stratos-7B + bfuzzy1/acheron-m1a-llama bofenghuang/vigogne-2-70b-chat - CohereForAI/c4ai-command-r-plus # Gated - databricks/dbrx-instruct # Gated - google/gemma-2-2b-it # Gated - google/gemma-7b-it # Gated - MiniMaxAI/MiniMax-Text-01 + bytedance-research/UI-TARS-72B-DPO + bytedance-research/UI-TARS-7B-DPO + bytedance-research/UI-TARS-7B-SFT + carsenk/phi3.5_mini_exp_825_uncensored + CohereForAI/aya-expanse-8b + CohereForAI/c4ai-command-r-plus + CohereForAI/c4ai-command-r7b-12-2024 + cyberagent/DeepSeek-R1-Distill-Qwen-14B-Japanese + cyberagent/DeepSeek-R1-Distill-Qwen-32B-Japanese + databricks/dbrx-instruct + DavieLion/Llama-3.2-1B-SPIN-iter3 + deepseek-ai/deepseek-coder-33b-instruct + deepseek-ai/deepseek-coder-6.7b-instruct + deepseek-ai/deepseek-coder-7b-instruct-v1.5 + deepseek-ai/DeepSeek-Coder-V2-Instruct + deepseek-ai/DeepSeek-Coder-V2-Lite-Base + deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct + deepseek-ai/deepseek-llm-67b-chat + deepseek-ai/deepseek-llm-7b-chat + deepseek-ai/DeepSeek-R1-Distill-Llama-70B + deepseek-ai/DeepSeek-R1-Distill-Llama-8B + deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + deepseek-ai/DeepSeek-R1-Distill-Qwen-14B + deepseek-ai/DeepSeek-R1-Distill-Qwen-32B + deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + deepseek-ai/DeepSeek-V2-Lite + deepseek-ai/DeepSeek-V2.5 + deepseek-ai/DeepSeek-V3 + Delta-Vector/Rei-12B + dicta-il/dictalm2.0-instruct + ehristoforu/Falcon3-8B-Franken-Basestruct + EpistemeAI/Mistral-Nemo-Instruct-12B-Philosophy-Math + FlofloB/83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit + 
FlofloB/test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit + godlikehhd/alpaca_data_sampled_ifd_new_5200 + godlikehhd/alpaca_data_score_max_0.7_2600 + google/gemma-2-27b-it + google/gemma-2-2b-it + google/gemma-2-2b-jpn-it + google/gemma-7b-it + HelpingAI/HAI-SER + HuggingFaceTB/SmolLM2-1.7B-Instruct + HuggingFaceTB/SmolLM2-135M-Instruct + HuggingFaceTB/SmolLM2-360M-Instruct + HuggingFaceTB/SmolLM3-3B + huihui-ai/DeepSeek-R1-Distill-Llama-70B-abliterated + huihui-ai/DeepSeek-R1-Distill-Llama-8B-abliterated + huihui-ai/DeepSeek-R1-Distill-Qwen-14B-abliterated-v2 + huihui-ai/DeepSeek-R1-Distill-Qwen-32B-abliterated + huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated-v2 + huihui-ai/Qwen2.5-14B-Instruct-1M-abliterated + ibm-granite/granite-3.1-8b-instruct + Ihor/Text2Graph-R1-Qwen2.5-0.5b + inclusionAI/Ling-Coder-lite indischepartij/MiniCPM-3B-OpenHermes-2.5-v2 + Infinigence/Megrez-3B-Instruct + inflatebot/MN-12B-Mag-Mell-R1 + INSAIT-Institute/BgGPT-Gemma-2-27B-IT-v1.0 + jinaai/ReaderLM-v2 + Josephgflowers/TinyLlama_v1.1_math_code-world-test-1 + kms7530/chemeng_qwen-math-7b_24_1_100_1_nonmath + knifeayumu/Cydonia-v1.3-Magnum-v4-22B + langgptai/qwen1.5-7b-chat-sa-v0.1 + LatitudeGames/Wayfarer-12B + llava-hf/llava-1.5-7b-hf + LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct + LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct + lightblue/DeepSeek-R1-Distill-Qwen-7B-Japanese + Magpie-Align/Llama-3-8B-Magpie-Align-v0.1 + Magpie-Align/Llama-3.1-8B-Magpie-Align-v0.1 mattshumer/Reflection-Llama-3.1-70B + MaziyarPanahi/calme-3.2-instruct-78b + meetkai/functionary-medium-v3.1 meetkai/functionary-medium-v3.2 - meta-llama/Llama-3.1-8B-Instruct # Gated - meta-llama/Llama-3.2-3B-Instruct # Gated - meta-llama/Llama-3.3-70B-Instruct # Gated - meta-llama/Meta-Llama-3.1-8B-Instruct # Gated + meta-llama/Llama-2-7b-chat-hf + meta-llama/Llama-3.1-8B-Instruct + meta-llama/Llama-3.2-1B-Instruct + meta-llama/Llama-3.2-3B-Instruct + meta-llama/Llama-3.3-70B-Instruct + 
meta-llama/Meta-Llama-3-8B-Instruct + meta-llama/Meta-Llama-3.1-8B-Instruct microsoft/Phi-3-medium-4k-instruct microsoft/Phi-3-mini-4k-instruct microsoft/Phi-3-small-8k-instruct microsoft/Phi-3.5-mini-instruct microsoft/Phi-3.5-vision-instruct - mistralai/Mistral-7B-Instruct-v0.2 # Gated - mistralai/Mistral-Large-Instruct-2407 # Gated - mistralai/Mistral-Large-Instruct-2411 # Gated - mistralai/Mistral-Nemo-Instruct-2407 # Gated - mistralai/Mixtral-8x7B-Instruct-v0.1 # Gated + microsoft/phi-4 + migtissera/Tess-3-Mistral-Nemo-12B + MiniMaxAI/MiniMax-Text-01 + MiniMaxAI/MiniMax-VL-01 + ministral/Ministral-3b-instruct + mistralai/Codestral-22B-v0.1 + mistralai/Mistral-7B-Instruct-v0.1 + mistralai/Mistral-7B-Instruct-v0.2 + mistralai/Mistral-7B-Instruct-v0.3 + mistralai/Mistral-Large-Instruct-2407 + mistralai/Mistral-Large-Instruct-2411 + mistralai/Mistral-Nemo-Instruct-2407 + mistralai/Mistral-Small-24B-Instruct-2501 + mistralai/Mixtral-8x7B-Instruct-v0.1 + mkurman/Qwen2.5-14B-DeepSeek-R1-1M mlabonne/AlphaMonarch-7B + mlx-community/Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32 + mlx-community/Qwen2.5-VL-7B-Instruct-8bit + mobiuslabsgmbh/DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1 + NaniDAO/deepseek-r1-qwen-2.5-32B-ablated + netcat420/MFANNv0.20 + netcat420/MFANNv0.24 + netease-youdao/Confucius-o1-14B NexaAIDev/Octopus-v2 NousResearch/Hermes-2-Pro-Llama-3-8B NousResearch/Hermes-2-Pro-Mistral-7B NousResearch/Hermes-3-Llama-3.1-70B + NovaSky-AI/Sky-T1-32B-Flash + NovaSky-AI/Sky-T1-32B-Preview + nvidia/AceMath-7B-RM + nvidia/Eagle2-1B + nvidia/Eagle2-9B nvidia/Llama-3.1-Nemotron-70B-Instruct-HF + OnlyCheeini/greesychat-turbo + onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX + open-thoughts/OpenThinker-7B + openbmb/MiniCPM3-4B openchat/openchat-3.5-0106 + Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2 OrionStarAI/Orion-14B-Chat + pankajmathur/orca_mini_v6_8b + PowerInfer/SmallThinker-3B-Preview + PrimeIntellect/INTELLECT-1-Instruct + princeton-nlp/Mistral-7B-Base-SFT-RDPO + 
princeton-nlp/Mistral-7B-Instruct-DPO + princeton-nlp/Mistral-7B-Instruct-RDPO + prithivMLmods/Bellatrix-Tiny-1.5B-R1 + prithivMLmods/Bellatrix-Tiny-1B-R1 + prithivMLmods/Bellatrix-Tiny-1B-v3 + prithivMLmods/Bellatrix-Tiny-3B-R1 + prithivMLmods/Blaze-14B-xElite + prithivMLmods/Calcium-Opus-14B-Elite2-R1 + prithivMLmods/Calme-Ties-78B + prithivMLmods/Calme-Ties2-78B + prithivMLmods/Calme-Ties3-78B + prithivMLmods/ChemQwen2-vL + prithivMLmods/GWQ2b + prithivMLmods/LatexMind-2B-Codec + prithivMLmods/Llama-3.2-6B-AlgoCode + prithivMLmods/Megatron-Opus-14B-Exp + prithivMLmods/Megatron-Opus-14B-Stock + prithivMLmods/Megatron-Opus-7B-Exp + prithivMLmods/Omni-Reasoner-Merged + prithivMLmods/Omni-Reasoner4-Merged + prithivMLmods/Primal-Opus-14B-Optimus-v1 + prithivMLmods/Qwen-7B-Distill-Reasoner + prithivMLmods/Qwen2.5-1.5B-DeepSeek-R1-Instruct + prithivMLmods/Qwen2.5-14B-DeepSeek-R1-1M + prithivMLmods/Qwen2.5-32B-DeepSeek-R1-Instruct + prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M + prithivMLmods/QwQ-Math-IO-500M + prithivMLmods/Triangulum-v2-10B + Qwen/QVQ-72B-Preview + Qwen/Qwen1.5-7B-Chat Qwen/Qwen2-7B-Instruct + Qwen/Qwen2-VL-72B-Instruct Qwen/Qwen2-VL-7B-Instruct + Qwen/Qwen2.5-0.5B + Qwen/Qwen2.5-1.5B-Instruct + Qwen/Qwen2.5-14B + Qwen/Qwen2.5-14B-Instruct-1M + Qwen/Qwen2.5-32B + Qwen/Qwen2.5-32B-Instruct + Qwen/Qwen2.5-3B-Instruct + Qwen/Qwen2.5-72B-Instruct + Qwen/Qwen2.5-7B Qwen/Qwen2.5-7B-Instruct + Qwen/Qwen2.5-7B-Instruct-1M + Qwen/Qwen2.5-Coder-32B-Instruct + Qwen/Qwen2.5-Coder-7B-Instruct + Qwen/Qwen2.5-Math-1.5B Qwen/Qwen2.5-Math-7B-Instruct + Qwen/Qwen2.5-VL-3B-Instruct + Qwen/Qwen2.5-VL-72B-Instruct + Qwen/Qwen2.5-VL-7B-Instruct Qwen/QwQ-32B-Preview + rubenroy/Zurich-14B-GCv2-5m + rubenroy/Zurich-7B-GCv2-5m + RWKV-Red-Team/ARWKV-7B-Preview-0.1 + SakanaAI/TinySwallow-1.5B + SakanaAI/TinySwallow-1.5B-Instruct + Sao10K/70B-L3.3-Cirrus-x1 + SentientAGI/Dobby-Mini-Leashed-Llama-3.1-8B + SentientAGI/Dobby-Mini-Unhinged-Llama-3.1-8B + 
silma-ai/SILMA-Kashif-2B-Instruct-v1.0 + simplescaling/s1-32B + sometimesanotion/Lamarck-14B-v0.7 + sonthenguyen/zephyr-sft-bnb-4bit-DPO-mtbr-180steps + Steelskull/L3.3-Damascus-R1 + Steelskull/L3.3-MS-Nevoria-70b + Steelskull/L3.3-Nevoria-R1-70b + sthenno/tempesthenno-icy-0130 + sumink/qwft + Tarek07/Progenitor-V1.1-LLaMa-70B teknium/OpenHermes-2.5-Mistral-7B TheBloke/FusionNet_34Bx2_MoE-AWQ + thirdeyeai/elevate360m + THUDM/glm-4-9b-chat + THUDM/glm-edge-1.5b-chat + tiiuae/Falcon3-10B-Instruct + TinyLlama/TinyLlama-1.1B-Chat-v1.0 + UCLA-AGI/Mistral7B-PairRM-SPPO-Iter3 + unsloth/DeepSeek-R1-Distill-Llama-8B + unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit + unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit + upstage/solar-pro-preview-instruct + ValiantLabs/Llama3.1-8B-Enigma + xwen-team/Xwen-72B-Chat + xwen-team/Xwen-7B-Chat + Qwen/Qwen3-4B + Qwen/Qwen3-235B-A22B-Instruct-2507 + Qwen/Qwen3-235B-A22B-Thinking-2507 + Qwen/Qwen3-Coder-30B-A3B-Instruct + Qwen/QwQ-32B # Broken, TODO: - # meetkai/functionary-medium-v3.1 # jinja2 expectation is computed w/ wrong escapes - # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7 # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 + # Almawave/Velvet-14B + # deepseek-ai/DeepSeek-R1 + # deepseek-ai/DeepSeek-R1-Zero + # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7 + # HuggingFaceTB/SmolVLM-256M-Instruct + # HuggingFaceTB/SmolVLM-500M-Instruct + # HuggingFaceTB/SmolVLM-Instruct + # meta-llama/Llama-3.2-11B-Vision-Instruct + # unsloth/DeepSeek-R1 ) -if(NOT WIN32) - list(APPEND MODEL_IDS - # Needs investigation - deepseek-ai/deepseek-coder-33b-instruct - deepseek-ai/DeepSeek-Coder-V2-Instruct - deepseek-ai/DeepSeek-V2.5 - deepseek-ai/DeepSeek-R1-Distill-Llama-8B - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B - deepseek-ai/DeepSeek-R1-Distill-Qwen-32B +if(WIN32) + list(REMOVE_ITEM MODEL_IDS + # Needs investigation 
(https://github.com/google/minja/issues/40) + CohereForAI/c4ai-command-r7b-12-2024 ) endif() @@ -132,8 +367,8 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES}) separate_arguments(test_args UNIX_COMMAND "${test_case}") list(GET test_args -1 last_arg) string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}") - add_test(NAME ${test_name} COMMAND $ ${test_args}) - set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127) + add_test(NAME test-supported-template-${test_name} COMMAND $ ${test_args}) + set_tests_properties(test-supported-template-${test_name} PROPERTIES SKIP_RETURN_CODE 127) endforeach() if (MINJA_FUZZTEST_ENABLED) diff --git a/tests/contexts/simple.json b/tests/contexts/simple.json index 560f92f..5e89f22 100644 --- a/tests/contexts/simple.json +++ b/tests/contexts/simple.json @@ -11,5 +11,6 @@ ], "add_generation_prompt": true, "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>" + "eos_token": "<|endoftext|>", + "tools_in_user_message": false } diff --git a/tests/contexts/system.json b/tests/contexts/system.json index 4d72972..7cbc5c2 100644 --- a/tests/contexts/system.json +++ b/tests/contexts/system.json @@ -15,5 +15,6 @@ ], "add_generation_prompt": true, "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>" + "eos_token": "<|endoftext|>", + "tools_in_user_message": false } diff --git a/tests/contexts/tool_use.json b/tests/contexts/tool_use.json index 4920d19..cca70cb 100644 --- a/tests/contexts/tool_use.json +++ b/tests/contexts/tool_use.json @@ -88,6 +88,7 @@ "add_generation_prompt": true, "bos_token": "<|startoftext|>", "eos_token": "<|endoftext|>", + "tools_in_user_message": false, "builtin_tools": [ "wolfram_alpha", "brave_search" @@ -96,72 +97,72 @@ "todays_date": "2024-09-03", "tools": [ { - "type": "function", "function": { - "name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "name": "ipython", "parameters": { - 
"type": "object", "properties": { "code": { - "type": "string", - "description": "The code to run in the ipython interpreter." + "description": "The code to run in the ipython interpreter.", + "type": "string" } }, - "required": ["code"] + "required": ["code"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "brave_search", "description": "Executes a web search with Brave.", + "name": "brave_search", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to search for." + "description": "The query to search for.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", + "name": "wolfram_alpha", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to execute." + "description": "The query to execute.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "test", "description": "Runs a test.", + "name": "test", "parameters": { - "type": "object", "properties": { "condition": { - "type": "boolean", - "description": "The condition to test." + "description": "The condition to test.", + "type": "boolean" } }, - "required": ["condition"] + "required": ["condition"], + "type": "object" } - } + }, + "type": "function" } ] } \ No newline at end of file diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp index 225581f..458f9b9 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. 
*/ // SPDX-License-Identifier: MIT -#include "chat-template.hpp" +#include "minja/chat-template.hpp" #include #include @@ -75,6 +75,30 @@ TEST(CapabilitiesTest, Gemma7b) { EXPECT_FALSE(caps.requires_typed_content); } +TEST(CapabilitiesTest, QwQ32B) { + auto caps = get_caps("tests/Qwen-QwQ-32B.jinja"); + EXPECT_TRUE(caps.supports_system_role); + EXPECT_TRUE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tool_calls); + EXPECT_TRUE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_parallel_tool_calls); + EXPECT_TRUE(caps.requires_object_arguments); + // EXPECT_TRUE(caps.requires_non_null_content); + EXPECT_FALSE(caps.requires_typed_content); +} + +TEST(CapabilitiesTest, Qwen3Coder) { + auto caps = get_caps("tests/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja"); + EXPECT_TRUE(caps.supports_system_role); + EXPECT_TRUE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tool_calls); + EXPECT_TRUE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_parallel_tool_calls); + EXPECT_TRUE(caps.requires_object_arguments); + // EXPECT_TRUE(caps.requires_non_null_content); + EXPECT_FALSE(caps.requires_typed_content); +} + #ifndef _WIN32 TEST(CapabilitiesTest, DeepSeekR1Distill) { @@ -141,7 +165,7 @@ TEST(CapabilitiesTest, MetaLlama3_3_70BInstruct) { TEST(CapabilitiesTest, MiniMaxAIText01) { auto caps = get_caps("tests/MiniMaxAI-MiniMax-Text-01.jinja"); EXPECT_TRUE(caps.supports_system_role); - EXPECT_FALSE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tools); EXPECT_FALSE(caps.supports_tool_calls); EXPECT_FALSE(caps.supports_tool_responses); EXPECT_FALSE(caps.supports_parallel_tool_calls); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index f6110cd..0831275 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,3 +1,4 @@ + /* Copyright 2024 Google LLC @@ -6,151 +7,63 @@ https://opensource.org/licenses/MIT. 
*/ // SPDX-License-Identifier: MIT -#include "chat-template.hpp" +#include "minja/chat-template.hpp" +#include "gtest/gtest.h" +#include +#include -#include -#include -#include -#include #include -#include - -#undef NDEBUG -#include +#include +#include -using json = nlohmann::ordered_json; +using namespace minja; +using namespace testing; + +static std::string render_python(const std::string & template_str, const chat_template_inputs & inputs) { + json bindings = inputs.extra_context; + bindings["messages"] = inputs.messages; + bindings["tools"] = inputs.tools; + bindings["add_generation_prompt"] = inputs.add_generation_prompt; + json data { + {"template", template_str}, + {"bindings", bindings}, + {"options", { + {"trim_blocks", true}, + {"lstrip_blocks", true}, + {"keep_trailing_newline", false}, + }}, + }; + { + std::ofstream of("data.json"); + of << data.dump(2); + of.close(); + } -template -static void assert_equals(const T &expected, const T &actual){ - if (expected != actual) { - std::cerr << "Expected: " << expected << "\n\n"; - std::cerr << "Actual: " << actual << "\n\n"; - auto i_divergence = std::min(expected.size(), actual.size()); - for (size_t i = 0; i < i_divergence; i++) { - if (expected[i] != actual[i]) { - i_divergence = i; - break; - } - } - std::cerr << "Divergence at index " << i_divergence << "\n\n"; - std::cerr << "Expected suffix: " << expected.substr(i_divergence) << "\n\n"; - std::cerr << "Actual suffix: " << actual.substr(i_divergence) << "\n\n"; + auto pyExeEnv = getenv("PYTHON_EXECUTABLE"); + std::string pyExe = pyExeEnv ? 
pyExeEnv : "python3"; - std::cerr << std::flush; - throw std::runtime_error("Test failed"); + std::remove("out.txt"); + auto res = std::system((pyExe + " -m scripts.render data.json out.txt").c_str()); + if (res != 0) { + throw std::runtime_error("Failed to run python script with data: " + data.dump(2)); } -} -static std::string read_file(const std::string &path) { - std::ifstream fs(path, std::ios_base::binary); - if (!fs.is_open()) { - throw std::runtime_error("Failed to open file: " + path); - } - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - std::string out; - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + std::ifstream f("out.txt"); + std::string out((std::istreambuf_iterator(f)), std::istreambuf_iterator()); return out; } -#ifndef _WIN32 -static json caps_to_json(const minja::chat_template_caps &caps) { - return { - {"supports_tools", caps.supports_tools}, - {"supports_tool_calls", caps.supports_tool_calls}, - {"supports_tool_responses", caps.supports_tool_responses}, - {"supports_system_role", caps.supports_system_role}, - {"supports_parallel_tool_calls", caps.supports_parallel_tool_calls}, - {"supports_tool_call_id", caps.supports_tool_call_id}, - {"requires_object_arguments", caps.requires_object_arguments}, - // {"requires_non_null_content", caps.requires_non_null_content}, - {"requires_typed_content", caps.requires_typed_content}, - }; +static std::string render(const std::string & template_str, const chat_template_inputs & inputs, const chat_template_options & opts) { + if (getenv("USE_JINJA2")) { + return render_python(template_str, inputs); + } + chat_template tmpl( + template_str, + "", + ""); + return tmpl.apply(inputs, opts); } -#endif - -int main(int argc, char *argv[]) { - if (argc != 5) - { - std::cerr << "Usage: " << argv[0] << " " << std::endl; - for (int i = 0; i < argc; i++) - { - std::cerr << "argv[" << i << "] = " << argv[i] << std::endl; - } - return 1; - } - - try { - std::string 
tmpl_file = argv[1]; - std::string caps_file = argv[2]; - std::string ctx_file = argv[3]; - std::string golden_file = argv[4]; - - auto tmpl_str = read_file(tmpl_file); - - if (ctx_file == "n/a") - { - std::cout << "# Skipping template: " << tmpl_file << "\n" << tmpl_str << std::endl; - return 127; - } - - std::cout << "# Testing template: " << tmpl_file << std::endl - << "# With caps: " << caps_file << std::endl - << "# With context: " << ctx_file << std::endl - << "# Against golden file: " << golden_file << std::endl - << std::flush; - - auto ctx = json::parse(read_file(ctx_file)); - minja::chat_template tmpl( - tmpl_str, - ctx.at("bos_token"), - ctx.at("eos_token")); - - std::string expected; - try { - expected = minja::normalize_newlines(read_file(golden_file)); - } catch (const std::exception &e) { - std::cerr << "Failed to read golden file: " << golden_file << std::endl; - std::cerr << e.what() << std::endl; - return 1; - } - - struct minja::chat_template_inputs inputs; - inputs.messages = ctx.at("messages"); - inputs.tools = ctx.contains("tools") ? ctx.at("tools") : json(); - inputs.add_generation_prompt = ctx.at("add_generation_prompt"); - if (ctx.contains("tools")) { - inputs.extra_context = json { - {"builtin_tools", { - {"wolfram_alpha", "brave_search"} - }}, - }; - } - std::string actual; - try { - actual = tmpl.apply(inputs); - } catch (const std::exception &e) { - std::cerr << "Error applying template: " << e.what() << std::endl; - return 1; - } - - assert_equals(expected, actual); - - // Some unresolved CRLF issues again with the goldens on Windows. -#ifndef _WIN32 - // Checks that the Python & C++ capability detection codes are in sync. - auto expected_caps = minja::normalize_newlines(read_file(caps_file)); - auto caps = caps_to_json(tmpl.original_caps()).dump(2); - assert_equals(expected_caps, caps); -#endif - - std::cout << "Test passed successfully." 
<< std::endl; - return 0; - } catch (const std::exception &e) { - std::cerr << "Test failed: " << e.what() << std::endl; - return 1; - } +TEST(ChatTemplateTest, SimpleCases) { + EXPECT_THAT(render("{{ strftime_now('%Y-%m-%d %H:%M:%S') }}", {}, {}), MatchesRegex(R"([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2})")); } diff --git a/tests/test-fuzz.cpp b/tests/test-fuzz.cpp index c256079..7169bce 100644 --- a/tests/test-fuzz.cpp +++ b/tests/test-fuzz.cpp @@ -9,8 +9,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp new file mode 100644 index 0000000..5bc1226 --- /dev/null +++ b/tests/test-polyfills.cpp @@ -0,0 +1,579 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#include "minja/minja.hpp" +#include +#include + +#include +#include +#include +#include "minja/chat-template.hpp" + +using namespace minja; + +static std::string read_file(const std::string &path) +{ + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) + { + throw std::runtime_error("Failed to open file: " + path); + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + +#define TEMPLATE_CHATML \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_CHATML_NO_SYSTEM \ + "{%- for message in messages -%}\n" \ + " {%- if message.role == 'system' -%}\n" \ + " {{- raise_exception('System role not supported') -}}\n" \ + " {%- endif -%}\n" \ + " {{- 
'<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_DUMMY \ + "{%- for tool in tools -%}\n" \ + " {{- 'tool: ' + (tool | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- for message in messages -%}\n" \ + " {{- 'message: ' + (message | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- 'message: ' -}}\n" \ + "{%- endif -%}" + + +const json message_user_text { + { "role", "user" }, + { "content", "I need help" }, +}; +const json message_assistant_text { + { "role", "assistant" }, + { "content", "Hello, world!" }, +}; +const json message_system { + { "role", "system" }, + { "content", "I am The System!" }, +}; +const json tool_calls = json::array({{ + { "type", "function" }, + { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, +}}); + +const json message_assistant_call { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, +}; +const json message_assistant_call_id { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "123456789"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const json message_assistant_call_idx { + { "role", "assistant"}, + { "content", {}}, + { "tool_plan", "I'm not so sure"}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "0"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const 
json message_tool { + { "role", "tool"}, + { "content", { + {"result", 123}, + }}, + { "tool_call_id", "123456789"}, +}; + +const auto special_function_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + } + } +})"); + +auto ThrowsWithSubstr = [](const std::string & expected_substr) { + return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); +}; + +static chat_template_options options_no_polyfills() { + chat_template_options opts; + opts.apply_polyfills = false; + opts.polyfill_system_role = false; + opts.polyfill_tools = false; + opts.polyfill_tool_call_examples = false; + opts.polyfill_tool_calls = false; + opts.polyfill_tool_responses = false; + opts.polyfill_object_arguments = false; + opts.polyfill_typed_content = false; + return opts; +}; + +TEST(PolyfillTest, NoPolyFill) { + chat_template tmpl(TEMPLATE_CHATML, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.add_generation_prompt = false; + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.messages = json::array({message_user_text, message_assistant_text}); + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "Hello, world!<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); +} + +TEST(PolyfillTest, SystemRoleSupported) { + chat_template chatml(TEMPLATE_CHATML, "", ""); + chat_template dummy(TEMPLATE_DUMMY, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, 
message_user_text}); + + EXPECT_EQ( + "<|im_start|>system\n" + "I am The System!<|im_end|>\n" + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + chatml.apply(inputs)); + EXPECT_EQ( + "message: {\n" + " \"role\": \"system\",\n" + " \"content\": \"I am The System!\"\n" + "}\n" + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: ", + dummy.apply(inputs)); +} + +TEST(PolyfillTest, SystemRolePolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, message_user_text}); + + EXPECT_THAT( + [&]() { tmpl.apply(inputs, options_no_polyfills()); }, + ThrowsWithSubstr("System role not supported")); + + EXPECT_EQ( + "<|im_start|>user\n" + "I am The System!\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " }\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallPolyfill) { + chat_template tmpl(TEMPLATE_CHATML, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": 
{\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolsPolyfill) { + chat_template tmpl(TEMPLATE_CHATML, "", "<|im_end|>"); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + inputs.tools = json::array({special_function_tool}); + + EXPECT_EQ( + "<|im_start|>system\n" + "You can call any of the following tools to satisfy the user's requests: [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]\n" + "\n" + "Example tool call syntax:\n" + "\n" + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"tool_name\",\n" + " \"arguments\": {\n" + " \"arg1\": \"some_value\"\n" + " },\n" + " \"id\": \"call_1___\"\n" + " }\n" + " ]\n" + "}\n\n<|im_end|>\n" + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"tool\",\n" + " \"content\": {\n" + " \"result\": 123\n" + " },\n" + " \"tool_call_id\": \"123456789\"\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>user\n{\n" + " \"tool_response\": {\n" + " \"content\": {\n" + " \"result\": 123\n" + " },\n" + " 
\"tool_call_id\": \"123456789\"\n" + " }\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +#ifndef _WIN32 +TEST(ToolTest, DeepSeekR1) { + chat_template tmpl(read_file("tests/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|tool▁outputs▁begin|><|tool▁output▁begin|>{'result': 123}<|tool▁output▁end|><|tool▁outputs▁end|>", + tmpl.apply(inputs)); +} + +TEST(ToolTest, CommandR7b) { + chat_template tmpl(read_file("tests/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n" + "You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n" + "\n" + "Your information cutoff date is June 2024.\n" + "\n" + "You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n" + "# Default Preamble\n" + "The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n" + "- Your name is Command.\n" + "- You are a large language model built by Cohere.\n" + "- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n" + "- If the input is ambiguous, ask clarifying follow-up questions.\n" + "- Use 
Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n" + "- Use LaTeX to generate mathematical notation for complex equations.\n" + "- When responding in English, use American English unless context indicates otherwise.\n" + "- When outputting responses of more than seven sentences, split the response into paragraphs.\n" + "- Prefer the active voice.\n" + "- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n" + "- Use gender-neutral pronouns for unspecified persons.\n" + "- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n" + "- Use the third person when asked to write a summary.\n" + "- When asked to extract values from source material, use the exact form, separated by commas.\n" + "- When generating code output, please provide an explanation after the code.\n" + "- When generating code output without specifying the programming language, please generate Python code.\n" + "- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n" + " {\n" + " \"tool_call_id\": \"\",\n" + " \"results\": {\n" + " \"0\": {\"result\": 123}\n" + " },\n" + " \"is_error\": null\n" + " }\n" + "]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + tmpl.apply(inputs)); +} +#endif // NOT _WIN32 + +TEST(ToolTest, MistralNemo) { + chat_template tmpl(read_file("tests/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "[TOOL_RESULTS]{\"content\": {'result': 123}, \"call_id\": 
\"123456789\"}[/TOOL_RESULTS]", + tmpl.apply(inputs)); +} + +TEST(ToolTest, NousResearchHermes3) { + chat_template tmpl(read_file("tests/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>system\n" + "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n" + "For each function call return a json object with function name and arguments within XML tags as follows:\n" + "\n" + "{\"name\": , \"arguments\": }\n" + "<|im_end|>\n" + "\n" + "{'result': 123}\n" + "<|im_end|><|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, NousResearchHermes2) { + chat_template tmpl(read_file("tests/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>system\n" + "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n" + "For each function call return a json object with function name and arguments within XML tags as follows:\n" + "\n" + "{\"name\": , \"arguments\": }\n" + "<|im_end|>\n" + "\n" + "{'result': 123}\n" + "<|im_end|><|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, Llama3_3) { + chat_template tmpl(read_file("tests/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "Cutting Knowledge Date: December 2023\n" + "Today Date: 26 Jul 2024\n" + "\n" + "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n" + "\n" + "{\"result\": 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, MeetkaiFunctionary3_1) { + chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.1.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "\n" + "Cutting Knowledge Date: December 2023\n" + "\n" + "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n" + "\n" + "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n", + tmpl.apply(inputs)); +} + +TEST(ToolTest, MeetkaiFunctionary3_2) { + chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.2.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|start_header_id|>system<|end_header_id|>\n" + "\n" + "You are capable of 
executing available function(s) if required.\n" + "Only execute function(s) when absolutely necessary.\n" + "Ask for the required input to:recipient==all\n" + "Use JSON for function arguments.\n" + "Respond in this format:\n" + ">>>${recipient}\n" + "${content}\n" + "Available functions:\n" + "// Supported function definitions that should be called when necessary.\n" + "namespace functions {\n" + "\n" + "} // namespace functions<|eot_id|><|start_header_id|>tool<|end_header_id|>\n" + "\n" + "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + "\n" + ">>>", + tmpl.apply(inputs)); +} + +/* +https://github.com/google/minja/issues/7 +TEST(ToolTest, FirefunctionV2) { + chat_template tmpl(read_file("tests/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>tool\n" + "{\n" + " \"result\": 123\n" + "}\n" + "<|im_end|>", + tmpl.apply(inputs)); +} +*/ diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp new file mode 100644 index 0000000..db23a4a --- /dev/null +++ b/tests/test-supported-template.cpp @@ -0,0 +1,178 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#include "minja/chat-template.hpp" + +#include +#include +#include +#include +#include +#include + +#undef NDEBUG +#include + +#define TEST_DATE (getenv("TEST_DATE") ? 
getenv("TEST_DATE") : "2024-07-26") + +using json = nlohmann::ordered_json; + +template +static void assert_equals(const T &expected, const T &actual){ + if (expected != actual) { + std::cerr << "Expected: " << expected << "\n\n"; + std::cerr << "Actual: " << actual << "\n\n"; + auto i_divergence = std::min(expected.size(), actual.size()); + for (size_t i = 0; i < i_divergence; i++) { + if (expected[i] != actual[i]) { + i_divergence = i; + break; + } + } + std::cerr << "Divergence at index " << i_divergence << "\n\n"; + std::cerr << "Expected suffix: " << expected.substr(i_divergence) << "\n\n"; + std::cerr << "Actual suffix: " << actual.substr(i_divergence) << "\n\n"; + + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +static std::string read_file(const std::string &path) { + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + +static void write_file(const std::string &path, const std::string &content) { + std::ofstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + fs.write(content.data(), content.size()); +} + +#ifndef _WIN32 +static json caps_to_json(const minja::chat_template_caps &caps) { + return { + {"supports_tools", caps.supports_tools}, + {"supports_tool_calls", caps.supports_tool_calls}, + {"supports_tool_responses", caps.supports_tool_responses}, + {"supports_system_role", caps.supports_system_role}, + {"supports_parallel_tool_calls", caps.supports_parallel_tool_calls}, + {"supports_tool_call_id", caps.supports_tool_call_id}, + {"requires_object_arguments", caps.requires_object_arguments}, + // {"requires_non_null_content", caps.requires_non_null_content}, + {"requires_typed_content", 
caps.requires_typed_content}, + }; +} +#endif + +int main(int argc, char *argv[]) { + if (argc != 5) + { + std::cerr << "Usage: " << argv[0] << " " << "\n"; + for (int i = 0; i < argc; i++) + { + std::cerr << "argv[" << i << "] = " << argv[i] << "\n"; + } + return 1; + } + + try { + std::string tmpl_file = argv[1]; + std::string caps_file = argv[2]; + std::string ctx_file = argv[3]; + std::string golden_file = argv[4]; + + auto tmpl_str = read_file(tmpl_file); + + if (ctx_file == "n/a") + { + std::cout << "# Skipping template: " << tmpl_file << "\n" << tmpl_str << "\n"; + return 127; + } + + std::cout << "# Testing template:\n" + << "# ./build/bin/test-supported-template " << json::array({tmpl_file, caps_file, ctx_file, golden_file}).dump() << std::endl + << std::flush; + + auto ctx = json::parse(read_file(ctx_file)); + + minja::chat_template tmpl( + tmpl_str, + ctx.at("bos_token"), + ctx.at("eos_token")); + + std::string expected; + try { + expected = minja::normalize_newlines(read_file(golden_file)); + } catch (const std::exception &e) { + std::cerr << "Failed to read golden file: " << golden_file << "\n"; + std::cerr << e.what() << "\n"; + return 1; + } + + struct minja::chat_template_inputs inputs; + inputs.messages = ctx.at("messages"); + ctx.erase("messages"); + + if (ctx.contains("tools")) { + inputs.tools = ctx.at("tools"); + ctx.erase("tools"); + } + inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + ctx.erase("add_generation_prompt"); + + std::istringstream ss(TEST_DATE); + std::tm tm = {}; + ss >> std::get_time(&tm, "%Y-%m-%d"); + inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + + inputs.extra_context = ctx; + + std::string actual; + try { + actual = tmpl.apply(inputs); + } catch (const std::exception &e) { + std::cerr << "Error applying template: " << e.what() << "\n"; + return 1; + } + + if (expected != actual) { + if (getenv("WRITE_GOLDENS")) { + write_file(golden_file, actual); + std::cerr << "Updated golden 
file: " << golden_file << "\n"; + } else { + assert_equals(expected, actual); + } + } + + // Some unresolved CRLF issues again with the goldens on Windows. +#ifndef _WIN32 + // Checks that the Python & C++ capability detection codes are in sync. + auto expected_caps = minja::normalize_newlines(read_file(caps_file)); + auto caps = caps_to_json(tmpl.original_caps()).dump(2); + assert_equals(expected_caps, caps); +#endif + + std::cout << "Test passed successfully." << "\n"; + return 0; + } catch (const std::exception &e) { + std::cerr << "Test failed: " << e.what() << "\n"; + return 1; + } +} diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index ebe5e19..36bdaa3 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -6,7 +6,7 @@ https://opensource.org/licenses/MIT. */ // SPDX-License-Identifier: MIT -#include "minja.hpp" +#include "minja/minja.hpp" #include #include @@ -73,6 +73,36 @@ TEST(SyntaxTest, SimpleCases) { auto ThrowsWithSubstr = [](const std::string & expected_substr) { return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); }; + + EXPECT_EQ("a", render("{{ ' a '.strip() }}", {}, {})); + EXPECT_EQ("a ", render("{{ ' a '.lstrip() }}", {}, {})); + EXPECT_EQ(" a", render("{{ ' a '.rstrip() }}", {}, {})); + EXPECT_EQ("bcXYZab", render("{{ 'abcXYZabc'.strip('ac') }}", {}, {})); + + EXPECT_EQ(R"(["a", "b"])", render("{{ 'a b'.split(' ') | tojson }}", {}, {})); + + EXPECT_EQ( + "Ok", + render("{{ 'ok'.capitalize() }}", {}, {})); + EXPECT_EQ("aouiXYZaouiXYZaoui", + render("{{ 'abcXYZabcXYZabc'.replace('bc', 'oui') }}", {}, {})); + EXPECT_EQ("okXYZokXYZabc", + render("{{ 'abcXYZabcXYZabc'.replace('abc', 'ok', 2) }}", {}, {})); + EXPECT_EQ("abcXYZabcXYZabc", + render("{{ 'abcXYZabcXYZabc'.replace('def', 'ok') }}", {}, {})); + + EXPECT_EQ("HELLO WORLD", render("{{ 'hello world'.upper() }}", {}, {})); + EXPECT_EQ("MIXED", render("{{ 'MiXeD'.upper() }}", {}, {})); + EXPECT_EQ("", render("{{ ''.upper() }}", 
{}, {})); + + EXPECT_EQ("hello world", render("{{ 'HELLO WORLD'.lower() }}", {}, {})); + EXPECT_EQ("mixed", render("{{ 'MiXeD'.lower() }}", {}, {})); + EXPECT_EQ("", render("{{ ''.lower() }}", {}, {})); + + EXPECT_EQ( + "ok", + render("{# Hey\nHo #}{#- Multiline...\nComments! -#}{{ 'ok' }}{# yo #}", {}, {})); + EXPECT_EQ( " b", render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks)); @@ -114,6 +144,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "abc", render("{{ 'AbC' | lower }}", {}, {})); + EXPECT_EQ( + "ME", + render("{{ 'me' | upper }}", {}, {})); EXPECT_EQ( "the default1", render("{{ foo | default('the default') }}{{ 1 | default('nope') }}", {}, {})); @@ -165,6 +198,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "1", render(R"({{ 1 | safe }})", {}, {})); + EXPECT_EQ( + "True,False", + render(R"({{ 'abc'.startswith('ab') }},{{ ''.startswith('a') }})", {}, {})); EXPECT_EQ( "True,False", render(R"({{ 'abc'.endswith('bc') }},{{ ''.endswith('a') }})", {}, {})); @@ -177,6 +213,14 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "True,False", render(R"({{ 'a' in ["a"] }},{{ 'a' in [] }})", {}, {})); + EXPECT_EQ("True,False", + render(R"({{ 'a' in 'abc' }},{{ 'd' in 'abc' }})", {}, {})); + EXPECT_EQ("False,True", + render(R"({{ 'a' not in 'abc' }},{{ 'd' not in 'abc' }})", {}, {})); + EXPECT_EQ("['a', 'a']", + render(R"({{ ['a', 'b', 'c', 'a'] | select('in', ['a']) | list }})", {}, {})); + EXPECT_EQ("['a', 'b'],[]", + render(R"({{ {'a': 1, 'b': 2}.keys() | list }},{{ {}.keys() | list }})", {}, {})); EXPECT_EQ( R"([{'a': 1}])", render(R"({{ [{"a": 1}, {"a": 2}, {}] | selectattr("a", "equalto", 1) | list }})", {}, {})); @@ -198,6 +242,18 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "False", render(R"({% set foo = true %}{{ not foo is defined }})", {}, {})); + EXPECT_EQ( + "True", + render(R"({% set foo = true %}{{ foo is true }})", {}, {})); + EXPECT_EQ( + "False", + render(R"({% set foo = true %}{{ foo is false }})", {}, {})); + EXPECT_EQ( + "True", + 
render(R"({% set foo = false %}{{ foo is not true }})", {}, {})); + EXPECT_EQ( + "False", + render(R"({% set foo = false %}{{ foo is not false }})", {}, {})); EXPECT_EQ( R"({"a": "b"})", render(R"({{ {"a": "b"} | tojson }})", {}, {})); @@ -373,10 +429,55 @@ TEST(SyntaxTest, SimpleCases) { {%- endmacro -%} {{- foo() }} {{ foo() -}})", {}, {})); + EXPECT_EQ( + "x,x", + render(R"( + {%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%} + {%- call test() -%}x{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "Outer[Inner(X)]", + render(R"( + {%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%} + {%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%} + {%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "
  • A
  • B
", + render(R"( + {%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%} + {%- set items = ["a", "b"] -%} + {%- call test("
    ", "
") -%} + {%- for item in items -%} +
  • {{ item | upper }}
  • + {%- endfor -%} + {%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "\\n\\nclass A:\\n b: 1\\n c: 2\\n", + render(R"( + {%- macro recursive(obj) -%} + {%- set ns = namespace(content = caller()) -%} + {%- for key, value in obj.items() %} + {%- if value is mapping %} + {%- call recursive(value) -%} + {{ '\\n\\nclass ' + key.title() + ':\\n' }} + {%- endcall -%} + {%- else -%} + {%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%} + {%- endif -%} + {%- endfor -%} + {{ ns.content }} + {%- endmacro -%} + + {%- call recursive({"a": {"b": "1", "c": "2"}}) -%} + {%- endcall -%} + )", {}, {})); + if (!getenv("USE_JINJA2")) { - EXPECT_EQ( - "[]", - render(R"({{ None | items | list | tojson }})", {}, {})); EXPECT_EQ( "Foo", render(R"({% generation %}Foo{% endgeneration %})", {}, {})); @@ -446,6 +547,15 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "[1, 2, 3][0, 1][1, 2]", render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {})); + EXPECT_EQ( + "123;01;12;0123;0123", + render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }};{{ x[:] }};{{ x[::] }}", {}, {})); + EXPECT_EQ( + "[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1][0, 2][3, 1][2, 0]", + render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}{{ x[::2] }}{{ x[::-2] }}{{ x[-2::-2] }}", {}, {})); + EXPECT_EQ( + "3210;321;210;21;02;31;20", + render("{% set x = '0123' %}{{ x[::-1] }};{{ x[:0:-1] }};{{ x[2::-1] }};{{ x[2:0:-1] }};{{ x[::2] }};{{ x[::-2] }};{{ x[-2::-2] }}", {}, {})); EXPECT_EQ( "a", render("{{ ' a ' | trim }}", {}, {})); @@ -486,10 +596,23 @@ TEST(SyntaxTest, SimpleCases) { "", render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {})); + EXPECT_EQ( + "0,1,2,", + render("{% for i in range(10) %}{{ i }},{% if i == 2 %}{% break %}{% endif %}{% endfor %}", {}, {})); + EXPECT_EQ( + "0,2,4,6,8,", + render("{% for i in range(10) %}{% if i % 2 %}{% continue %}{% endif %}{{ i }},{% endfor %}", {}, {})); if (!getenv("USE_JINJA2")) { // 
TODO: capture stderr from jinja2 and test these. + EXPECT_THAT([]() { render("{{ '' | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + EXPECT_THAT([]() { render("{{ [] | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + EXPECT_THAT([]() { render("{{ None | items }}", {}, {}); }, ThrowsWithSubstr("Can only get item pairs from a mapping")); + + EXPECT_THAT([]() { render("{% break %}", {}, {}); }, ThrowsWithSubstr("break outside of a loop")); + EXPECT_THAT([]() { render("{% continue %}", {}, {}); }, ThrowsWithSubstr("continue outside of a loop")); + EXPECT_THAT([]() { render("{%- set _ = [].pop() -%}", {}, {}); }, ThrowsWithSubstr("pop from empty list")); EXPECT_THAT([]() { render("{%- set _ = {}.pop() -%}", {}, {}); }, ThrowsWithSubstr("pop")); EXPECT_THAT([]() { render("{%- set _ = {}.pop('foooo') -%}", {}, {}); }, ThrowsWithSubstr("foooo")); @@ -501,6 +624,8 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif")); EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor")); EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter")); + EXPECT_THAT([]() { render("{% endmacro %}", {}, {}); }, ThrowsWithSubstr("Unexpected endmacro")); + EXPECT_THAT([]() { render("{% endcall %}", {}, {}); }, ThrowsWithSubstr("Unexpected endcall")); EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for")); @@ -508,6 +633,13 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% if 1 %}{% else %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, 
ThrowsWithSubstr("Unterminated filter")); + EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag")); + EXPECT_THAT([]() { render("{% macro test() %}", {}, {}); }, ThrowsWithSubstr("Unterminated macro")); + EXPECT_THAT([]() { render("{% call test %}", {}, {}); }, ThrowsWithSubstr("Unterminated call")); + + EXPECT_THAT([]() { + render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {}); + }, ThrowsWithSubstr("Invalid call block syntax - expected function call")); } EXPECT_EQ( diff --git a/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja b/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja new file mode 100644 index 0000000..e539012 --- /dev/null +++ b/tests_files/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- 
endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %}