From f0fd0d442df9e376c70a397d676d348d411e75be Mon Sep 17 00:00:00 2001
From: Konstantin Herud
Date: Sat, 25 May 2024 00:24:59 +0200
Subject: [PATCH 1/5] update logging

---
 CMakeLists.txt                                |   2 +
 src/main/cpp/jllama.cpp                       | 153 ++++++++++++++++--
 src/main/cpp/jllama.h                         |  77 ---------
 src/main/cpp/server.hpp                       |   4 -
 src/main/cpp/utils.hpp                        |  85 ++++++----
 src/main/java/de/kherud/llama/LlamaModel.java |  15 ++
 src/main/java/de/kherud/llama/LogLevel.java   |  13 ++
 .../java/de/kherud/llama/ModelParameters.java |  38 -----
 .../java/de/kherud/llama/args/LogFormat.java  |   5 +-
 src/test/java/examples/MainExample.java       |   3 +-
 10 files changed, 230 insertions(+), 165 deletions(-)
 delete mode 100644 src/main/cpp/jllama.h
 create mode 100644 src/main/java/de/kherud/llama/LogLevel.java

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f45ff00d..9746168b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,6 +5,8 @@ project(jllama CXX)
 include(FetchContent)
 
 set(BUILD_SHARED_LIBS ON)
+set(LLAMA_STATIC OFF)
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
 option(LLAMA_VERBOSE "llama: verbose output" OFF)
 
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index bd2bda48..4c087bf4 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -1,9 +1,10 @@
 #include "jllama.h"
 
-#include "nlohmann/json.hpp"
 #include "llama.h"
+#include "nlohmann/json.hpp"
 #include "server.hpp"
 
+#include <functional>
 #include <stdexcept>
 
 // We store some references to Java classes and their fields/methods here to speed up things for later and to fail
@@ -12,7 +13,7 @@
 namespace
 {
 
-// JavaVM *g_vm = nullptr;
+JavaVM* g_vm = nullptr;
 
 // classes
 jclass c_llama_model = nullptr;
@@ -29,6 +30,8 @@ jclass c_integer = nullptr;
 jclass c_float = nullptr;
 jclass c_biconsumer = nullptr;
 jclass c_llama_error = nullptr;
+jclass c_log_level = nullptr;
+jclass c_log_format = nullptr;
 jclass c_error_oom = nullptr;
 
 // constructors
@@ -55,9 +58,22 @@ jfieldID f_model_pointer = nullptr;
 jfieldID f_task_id = nullptr;
 jfieldID f_utf_8 = nullptr;
 jfieldID f_iter_has_next = nullptr;
+jfieldID f_log_level_debug = nullptr;
+jfieldID f_log_level_info = nullptr;
+jfieldID f_log_level_warn = nullptr;
+jfieldID f_log_level_error = nullptr;
+jfieldID f_log_format_json = nullptr;
+jfieldID f_log_format_text = nullptr;
 
 // objects
 jobject o_utf_8 = nullptr;
+jobject o_log_level_debug = nullptr;
+jobject o_log_level_info = nullptr;
+jobject o_log_level_warn = nullptr;
+jobject o_log_level_error = nullptr;
+jobject o_log_format_json = nullptr;
+jobject o_log_format_text = nullptr;
+jobject o_log_callback = nullptr;
 
 /**
  * Convert a Java string to a std::string
@@ -89,8 +105,40 @@ jbyteArray parse_jbytes(JNIEnv *env, const std::string &string)
     env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast<const jbyte *>(string.c_str()));
     return bytes;
 }
+
+/**
+ * Map a llama.cpp log level to its Java enumeration option.
+ */
+jobject log_level_to_jobject(ggml_log_level level)
+{
+    switch (level)
+    {
+    case GGML_LOG_LEVEL_ERROR:
+        return o_log_level_error;
+    case GGML_LOG_LEVEL_WARN:
+        return o_log_level_warn;
+    default: case GGML_LOG_LEVEL_INFO:
+        return o_log_level_info;
+    case GGML_LOG_LEVEL_DEBUG:
+        return o_log_level_debug;
+    }
+}
+
+/**
+ * Returns the JNIEnv of the current thread.
+ */
+JNIEnv* get_jni_env() {
+    JNIEnv* env = nullptr;
+    if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
+        throw std::runtime_error("Thread is not attached to the JVM");
+    }
+    return env;
+}
 } // namespace
 
+bool log_json;
+std::function<void(ggml_log_level, const char *, void *)> log_callback;
+
 /**
  * The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`).
  * `JNI_OnLoad` must return the JNI version needed by the native library.
@@ -101,6 +149,7 @@
  */
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
 {
+    g_vm = vm;
     JNIEnv *env = nullptr;
 
     if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1))
@@ -123,10 +172,13 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
     c_float = env->FindClass("java/lang/Float");
     c_biconsumer = env->FindClass("java/util/function/BiConsumer");
     c_llama_error = env->FindClass("de/kherud/llama/LlamaException");
+    c_log_level = env->FindClass("de/kherud/llama/LogLevel");
+    c_log_format = env->FindClass("de/kherud/llama/args/LogFormat");
     c_error_oom = env->FindClass("java/lang/OutOfMemoryError");
 
     if (!(c_llama_model && c_llama_iterator && c_standard_charsets && c_output && c_string && c_hash_map && c_map &&
-          c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_error_oom))
+          c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level &&
+          c_log_format && c_error_oom))
     {
         goto error;
     }
@@ -145,6 +197,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
     c_float = (jclass)env->NewGlobalRef(c_float);
     c_biconsumer = (jclass)env->NewGlobalRef(c_biconsumer);
     c_llama_error = (jclass)env->NewGlobalRef(c_llama_error);
+    c_log_level = (jclass)env->NewGlobalRef(c_log_level);
+    c_log_format = (jclass)env->NewGlobalRef(c_log_format);
     c_error_oom = (jclass)env->NewGlobalRef(c_error_oom);
 
     // find constructors
@@ -182,20 +236,40 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
     f_task_id = env->GetFieldID(c_llama_iterator, "taskId", "I");
     f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;");
     f_iter_has_next = env->GetFieldID(c_llama_iterator, "hasNext", "Z");
-
-    if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next))
+    f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;");
+    f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;");
+    f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;");
+    f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;");
+    f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lde/kherud/llama/args/LogFormat;");
+    f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;");
+
+    if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next && f_log_level_debug && f_log_level_info &&
+          f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text))
     {
         goto error;
     }
 
     o_utf_8 = env->NewStringUTF("UTF-8");
-
-    if (!(o_utf_8))
+    o_log_level_debug = env->GetStaticObjectField(c_log_level, f_log_level_debug);
+    o_log_level_info = env->GetStaticObjectField(c_log_level, f_log_level_info);
+    o_log_level_warn = env->GetStaticObjectField(c_log_level, f_log_level_warn);
+    o_log_level_error = env->GetStaticObjectField(c_log_level, f_log_level_error);
+    o_log_format_json = env->GetStaticObjectField(c_log_format, f_log_format_json);
+    o_log_format_text = env->GetStaticObjectField(c_log_format, f_log_format_text);
+
+    if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error &&
+          o_log_format_json && o_log_format_text))
     {
         goto error;
     }
 
-    o_utf_8 = (jclass)env->NewGlobalRef(o_utf_8);
+    o_utf_8 = env->NewGlobalRef(o_utf_8);
+    o_log_level_debug = env->NewGlobalRef(o_log_level_debug);
+    o_log_level_info = env->NewGlobalRef(o_log_level_info);
+    o_log_level_warn = env->NewGlobalRef(o_log_level_warn);
+    o_log_level_error = env->NewGlobalRef(o_log_level_error);
+    o_log_format_json = env->NewGlobalRef(o_log_format_json);
+    o_log_format_text = env->NewGlobalRef(o_log_format_text);
 
     if (env->ExceptionCheck())
     {
@@ -203,13 +277,15 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
         env->ExceptionDescribe();
         goto error;
     }
 
+    llama_backend_init();
+
     goto success;
 
 error:
     return JNI_ERR;
 
 success:
-    return JNI_VERSION_1_2;
+    return JNI_VERSION_1_6;
 }
 
 /**
@@ -224,7 +300,7 @@
 JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
 {
     JNIEnv *env = nullptr;
 
-    if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1))
+    if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6))
    {
         return;
     }
@@ -242,9 +318,24 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
     env->DeleteGlobalRef(c_float);
     env->DeleteGlobalRef(c_biconsumer);
     env->DeleteGlobalRef(c_llama_error);
+    env->DeleteGlobalRef(c_log_level);
+    env->DeleteGlobalRef(c_log_format);
     env->DeleteGlobalRef(c_error_oom);
 
     env->DeleteGlobalRef(o_utf_8);
+    env->DeleteGlobalRef(o_log_level_debug);
+    env->DeleteGlobalRef(o_log_level_info);
+    env->DeleteGlobalRef(o_log_level_warn);
+    env->DeleteGlobalRef(o_log_level_error);
+    env->DeleteGlobalRef(o_log_format_json);
+    env->DeleteGlobalRef(o_log_format_text);
+
+    if (o_log_callback != nullptr)
+    {
+        env->DeleteGlobalRef(o_log_callback);
+    }
+
+    llama_backend_free();
 }
 
 JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams)
@@ -277,7 +368,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
         params.model_alias = params.model;
     }
 
-    llama_backend_init();
     llama_numa_init(params.numa);
 
     LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
@@ -344,7 +434,17 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
                                         std::placeholders::_1, std::placeholders::_2, std::placeholders::_3));
 
-    std::thread t([ctx_server]() { ctx_server->queue_tasks.start_loop(); });
+    std::thread t([ctx_server]() {
+        JNIEnv *env;
+        jint res = g_vm->GetEnv((void**)&env, JNI_VERSION_1_6);
+        if (res == JNI_EDETACHED) {
+            res = g_vm->AttachCurrentThread((void**)&env, nullptr);
+            if (res != JNI_OK) {
+                throw std::runtime_error("Failed to attach thread to JVM");
+            }
+        }
+        ctx_server->queue_tasks.start_loop();
+    });
     t.detach();
 
     env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
@@ -502,8 +602,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobje
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
     ctx_server->queue_tasks.terminate();
-    // maybe we should keep track how many models were loaded before freeing the backend
-    llama_backend_free();
     delete ctx_server;
 }
 
@@ -514,3 +612,30 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *
     ctx_server->request_cancel(id_task);
ctx_server->queue_results.remove_waiting_task_id(id_task); } + +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format, + jobject jcallback) +{ + if (o_log_callback != nullptr) + { + env->DeleteGlobalRef(o_log_callback); + } + + log_json = env->IsSameObject(log_format, o_log_format_json); + + if (jcallback == nullptr) + { + log_callback = nullptr; + } + else + { + o_log_callback = env->NewGlobalRef(jcallback); + log_callback = [](enum ggml_log_level level, const char *text, void *user_data) { + JNIEnv* env = get_jni_env(); + jstring message = env->NewStringUTF(text); + jobject log_level = log_level_to_jobject(level); + env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message); + env->DeleteLocalRef(message); + }; + } +} diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h deleted file mode 100644 index 2c0125ac..00000000 --- a/src/main/cpp/jllama.h +++ /dev/null @@ -1,77 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class de_kherud_llama_LlamaModel */ - -#ifndef _Included_de_kherud_llama_LlamaModel -#define _Included_de_kherud_llama_LlamaModel -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: de_kherud_llama_LlamaModel - * Method: embed - * Signature: (Ljava/lang/String;)[F - */ -JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed - (JNIEnv *, jobject, jstring); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: encode - * Signature: (Ljava/lang/String;)[I - */ -JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode - (JNIEnv *, jobject, jstring); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: requestCompletion - * Signature: (Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion - (JNIEnv *, jobject, jstring); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: receiveCompletion - * Signature: (I)Lde/kherud/llama/LlamaOutput; - */ -JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion - (JNIEnv *, jobject, jint); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: cancelCompletion - * Signature: (I)V - */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion - (JNIEnv *, jobject, jint); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: decodeBytes - * Signature: ([I)[B - */ -JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes - (JNIEnv *, jobject, jintArray); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: loadModel - * Signature: (Ljava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel - (JNIEnv *, jobject, jstring); - -/* - * Class: de_kherud_llama_LlamaModel - * Method: delete - * Signature: ()V - */ -JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete - (JNIEnv *, jobject); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 23aa9057..b111bb7d 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -18,8 +18,6 @@ using json = nlohmann::ordered_json; -bool server_log_json = true; - enum stop_type { STOP_TYPE_FULL, @@ -2579,7 +2577,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params params.input_prefix = json_value(jparams, "input_prefix", default_params.input_prefix); params.input_suffix = json_value(jparams, "input_suffix", default_params.input_suffix); params.antiprompt = json_value(jparams, "antiprompt", default_params.antiprompt); - 
params.logdir = json_value(jparams, "logdir", default_params.logdir);
     params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static);
     params.lookup_cache_dynamic = json_value(jparams, "lookup_cache_dynamic", default_params.lookup_cache_dynamic);
     params.logits_file = json_value(jparams, "logits_file", default_params.logits_file);
@@ -2594,7 +2591,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
     params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
     params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock);
     params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload);
-    server_log_json = !jparams.contains("log_format") || jparams["log_format"] == "json";
 
     sparams.system_prompt = json_value(jparams, "system_prompt", default_sparams.system_prompt);
     sparams.chat_template = json_value(jparams, "chat_template", default_sparams.chat_template);
diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp
index 57391c40..56f6742a 100644
--- a/src/main/cpp/utils.hpp
+++ b/src/main/cpp/utils.hpp
@@ -26,23 +26,24 @@ enum error_type
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };
 
-extern bool server_log_json;
+extern bool log_json;
+extern std::function<void(ggml_log_level, const char *, void *)> log_callback;
 
 #if SERVER_VERBOSE
 #define LOG_VERBOSE(MSG, ...)                                                                                          \
     do                                                                                                                 \
     {                                                                                                                  \
-        server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__);                                                      \
+        server_log(GGML_LOG_LEVEL_DEBUG, __func__, __LINE__, MSG, __VA_ARGS__);                                        \
     } while (0)
 #else
 #define LOG_VERBOSE(MSG, ...)
 #endif
 
-#define LOG_ERROR(MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_ERROR(MSG, ...) server_log(GGML_LOG_LEVEL_ERROR, __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_WARNING(MSG, ...) server_log(GGML_LOG_LEVEL_WARN, __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_INFO(MSG, ...) server_log(GGML_LOG_LEVEL_INFO, __func__, __LINE__, MSG, __VA_ARGS__)
 
-static inline void server_log(const char *level, const char *function, int line, const char *message,
+static inline void server_log(ggml_log_level level, const char *function, int line, const char *message,
                               const json &extra);
 
 template <typename T> static T json_value(const json &body, const std::string &key, const T &default_value)
@@ -69,50 +70,76 @@ template <typename T> static T json_value(const json &body, const std::string &k
     }
 }
 
-static inline void server_log(const char *level, const char *function, int line, const char *message, const json &extra)
+static const char * log_level_to_string(ggml_log_level level) {
+    switch (level) {
+    case GGML_LOG_LEVEL_ERROR:
+        return "ERROR";
+    case GGML_LOG_LEVEL_WARN:
+        return "WARN";
+    default: case GGML_LOG_LEVEL_INFO:
+        return "INFO";
+    case GGML_LOG_LEVEL_DEBUG:
+        return "DEBUG";
+    }
+}
+
+static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, const json &extra)
 {
     std::stringstream ss_tid;
     ss_tid << std::this_thread::get_id();
-    json log = json{
-        {"tid", ss_tid.str()},
-        {"timestamp", time(nullptr)},
-    };
 
-    if (server_log_json)
+    if (log_json)
     {
-        log.merge_patch({
-            {"level", level},
+        json log = json{
+            {"msg", message},
+#if SERVER_VERBOSE
+            {"ts", time(nullptr)},
+            {"level", log_level_to_string(level)},
+            {"tid", ss_tid.str()},
             {"function", function},
             {"line", line},
-            {"msg", message},
-        });
+#endif
+        };
 
         if (!extra.empty())
         {
             log.merge_patch(extra);
         }
 
-        printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
+        auto dump = log.dump(-1, ' ', false, json::error_handler_t::replace);
+        if (log_callback == nullptr)
+        {
+            printf("%s\n", dump.c_str());
+        } else {
+            log_callback(level, dump.c_str(), nullptr);
+        }
     }
     else
    {
-        char buf[1024];
-        snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
-
+        std::stringstream ss;
+        ss << message;
         if (!extra.empty())
         {
-            log.merge_patch(extra);
-        }
-        std::stringstream ss;
-        ss << buf << " |";
-        for (const auto &el : log.items())
-        {
-            const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
-            ss << " " << el.key() << "=" << value;
+            for (const auto &el : extra.items())
+            {
+                const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
+                ss << " " << el.key() << "=" << value;
+            }
         }
+#if SERVER_VERBOSE
+        ss << " | ts " << time(nullptr)
+           << " | tid " << ss_tid.str()
+           << " | " << function << " line " << line;
+#endif
+
         const std::string str = ss.str();
-        printf("%.*s\n", (int)str.size(), str.data());
+        if (log_callback == nullptr) {
+            printf("[%4s] %.*s\n", log_level_to_string(level), (int)str.size(), str.data());
+        } else {
+            log_callback(level, str.c_str(), nullptr);
+        }
     }
     fflush(stdout);
 }
@@ -638,7 +665,7 @@ static json format_embeddings_response_oaicompat(const json &request, const json
 {
     json data = json::array();
     int i = 0;
-    for (auto &elem : embeddings)
+    for (const auto &elem : embeddings)
     {
         data.push_back(
             json{{"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}});
diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java
index aa1bb5ad..65fa29e5 100644
--- a/src/main/java/de/kherud/llama/LlamaModel.java
+++ b/src/main/java/de/kherud/llama/LlamaModel.java
@@ -1,7 +1,11 @@
 package de.kherud.llama;
 
+import de.kherud.llama.args.LogFormat;
+import org.jetbrains.annotations.Nullable;
+
 import java.lang.annotation.Native;
 import java.nio.charset.StandardCharsets;
+import java.util.function.BiConsumer;
 
 /**
  * This class is a wrapper around the llama.cpp functionality.
@@ -93,6 +97,17 @@ public String decode(int[] tokens) {
 		return new String(bytes, StandardCharsets.UTF_8);
 	}
 
+	/**
+	 * Sets a callback for native llama.cpp log messages.
+	 * By default, log messages are written to stdout. To only change the log format but keep logging to stdout,
+	 * the given callback can be null.
+	 * To disable logging, pass an empty callback, i.e., (level, msg) -> {}.
+	 *
+	 * @param format the log format to use
+	 * @param callback a method to call for log messages
+	 */
+	public static native void setLogger(LogFormat format, @Nullable BiConsumer<LogLevel, String> callback);
+
 	@Override
 	public void close() {
 		delete();
diff --git a/src/main/java/de/kherud/llama/LogLevel.java b/src/main/java/de/kherud/llama/LogLevel.java
new file mode 100644
index 00000000..b55c0898
--- /dev/null
+++ b/src/main/java/de/kherud/llama/LogLevel.java
@@ -0,0 +1,13 @@
+package de.kherud.llama;
+
+/**
+ * This enum represents the native log levels of llama.cpp.
+ */
+public enum LogLevel {
+
+	DEBUG,
+	INFO,
+	WARN,
+	ERROR
+
+}
diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java
index 67135de9..1cbb6973 100644
--- a/src/main/java/de/kherud/llama/ModelParameters.java
+++ b/src/main/java/de/kherud/llama/ModelParameters.java
@@ -3,7 +3,6 @@
 import java.util.Map;
 
 import de.kherud.llama.args.GpuSplitMode;
-import de.kherud.llama.args.LogFormat;
 import de.kherud.llama.args.NumaStrategy;
 import de.kherud.llama.args.PoolingType;
 import de.kherud.llama.args.RopeScalingType;
@@ -53,8 +52,6 @@ public final class ModelParameters extends JsonParameters {
 	private static final String PARAM_MODEL_URL = "model_url";
 	private static final String PARAM_HF_REPO = "hf_repo";
 	private static final String PARAM_HF_FILE = "hf_file";
-	private static final String PARAM_LOGDIR = "logdir";
-	private static final String PARAM_LOG_DISABLE = "disable_log";
 	private static final String PARAM_LOOKUP_CACHE_STATIC = "lookup_cache_static";
 	private static final String PARAM_LOOKUP_CACHE_DYNAMIC = "lookup_cache_dynamic";
 	private static final String PARAM_LORA_ADAPTER = "lora_adapter";
@@ -68,7 +65,6 @@ public final class ModelParameters extends JsonParameters {
 	private static final String PARAM_USE_MLOCK = "use_mlock";
 	private static final String PARAM_NO_KV_OFFLOAD = "no_kv_offload";
 	private static final String PARAM_SYSTEM_PROMPT = "system_prompt";
-	private static final String PARAM_LOG_FORMAT = "log_format";
 	private static final String PARAM_CHAT_TEMPLATE = "chat_template";
 
 	/**
@@ -447,22 +443,6 @@ public ModelParameters setHuggingFaceFile(String hfFile) {
 		return this;
 	}
 
-	/**
-	 * Set path under which to save YAML logs (no logging if unset)
-	 */
-	public ModelParameters setLogDirectory(String logdir) {
-		parameters.put(PARAM_LOGDIR, toJsonString(logdir));
-		return this;
-	}
-
-	/**
-	 * Set whether to disable logging
-	 */
-	public ModelParameters setDisableLog(boolean logDisable) {
-		parameters.put(PARAM_LOG_DISABLE, String.valueOf(logDisable));
-		return this;
-	}
-
 	/**
 	 * Set path to static lookup cache to use for lookup decoding (not updated by generation)
 	 */
@@ -584,24 +564,6 @@ public ModelParameters setSystemPrompt(String systemPrompt) {
 		return this;
 	}
 
-	/**
-	 * Set which log format to use
-	 */
-	public ModelParameters setLogFormat(LogFormat logFormat) {
-		switch (logFormat) {
-			case NONE:
-				parameters.put(PARAM_LOG_DISABLE, String.valueOf(true));
-				break;
-			case JSON:
-				parameters.put(PARAM_LOG_FORMAT, "json");
-				break;
-			case TEXT:
-				parameters.put(PARAM_LOG_FORMAT, "text");
-				break;
-		}
-		return this;
-	}
-
 	/**
 	 * The chat template to use (default: empty)
 	 */
diff --git a/src/main/java/de/kherud/llama/args/LogFormat.java b/src/main/java/de/kherud/llama/args/LogFormat.java
index f0e76492..8a5b46e8 100644
--- a/src/main/java/de/kherud/llama/args/LogFormat.java
+++ b/src/main/java/de/kherud/llama/args/LogFormat.java
@@ -1,8 +1,11 @@
 package de.kherud.llama.args;
 
+/**
+ * The log output format (defaults to JSON for all server-based outputs).
+ */
 public enum LogFormat {
 
-	NONE,
 	JSON,
 	TEXT
+
 }
diff --git a/src/test/java/examples/MainExample.java b/src/test/java/examples/MainExample.java
index 65e20c12..92581144 100644
--- a/src/test/java/examples/MainExample.java
+++ b/src/test/java/examples/MainExample.java
@@ -17,8 +17,7 @@ public class MainExample {
 	public static void main(String... args) throws IOException {
 		ModelParameters modelParams = new ModelParameters()
 				.setModelFilePath("models/mistral-7b-instruct-v0.2.Q2_K.gguf")
-				.setNGpuLayers(43)
-				.setDisableLog(true);
+				.setNGpuLayers(43);
 		String system = "This is a conversation between User and Llama, a friendly chatbot.\n" +
 				"Llama is helpful, kind, honest, good at writing, and never fails to answer any " +
 				"requests immediately and with precision.\n\n" +

From 9319bd124b649396a83b411c6443792eb57fb63b Mon Sep 17 00:00:00 2001
From: Konstantin Herud
Date: Sat, 25 May 2024 11:23:27 +0200
Subject: [PATCH 2/5] add logging unit tests

---
 .../java/de/kherud/llama/LlamaModelTest.java  | 111 +++++++++++++++++-
 1 file changed, 109 insertions(+), 2 deletions(-)

diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java
index 9659f975..a5454c59 100644
--- a/src/test/java/de/kherud/llama/LlamaModelTest.java
+++ b/src/test/java/de/kherud/llama/LlamaModelTest.java
@@ -1,8 +1,11 @@
 package de.kherud.llama;
 
-import java.util.HashMap;
-import java.util.Map;
+import java.io.*;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+import java.util.regex.Pattern;
 
+import de.kherud.llama.args.LogFormat;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -18,6 +21,7 @@ public class LlamaModelTest {
 
 	@BeforeClass
 	public static void setup() {
+//		LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> System.out.println(level + ": " + msg));
 		model = new LlamaModel(
 				new ModelParameters()
 						.setModelFilePath("models/codellama-7b.Q2_K.gguf")
@@ -161,4 +165,107 @@ public void testTokenization() {
 		// the llama tokenizer adds a space before the prompt
 		Assert.assertEquals(" " + prompt, decoded);
 	}
+
+	@Test
+	public void testLogText() {
+		List<LogMessage> messages = new ArrayList<>();
+		LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> messages.add(new LogMessage(level, msg)));
+
+		InferenceParameters params = new InferenceParameters(prefix)
+				.setNPredict(nPredict)
+				.setSeed(42);
+		model.complete(params);
+
+		Assert.assertFalse(messages.isEmpty());
+
+		Pattern jsonPattern = Pattern.compile("^\\s*[\\[{].*[}\\]]\\s*$");
+		for (LogMessage message : messages) {
+			Assert.assertNotNull(message.level);
+			Assert.assertFalse(jsonPattern.matcher(message.text).matches());
+		}
+	}
+
+	@Test
+	public void testLogJSON() {
+		List<LogMessage> messages = new ArrayList<>();
+		LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> messages.add(new LogMessage(level, msg)));
+
+		InferenceParameters params = new InferenceParameters(prefix)
+				.setNPredict(nPredict)
+				.setSeed(42);
+		model.complete(params);
+
+		Assert.assertFalse(messages.isEmpty());
+
+		Pattern jsonPattern = Pattern.compile("^\\s*[\\[{].*[}\\]]\\s*$");
+		for (LogMessage message : messages) {
+			Assert.assertNotNull(message.level);
+			Assert.assertTrue(jsonPattern.matcher(message.text).matches());
+		}
+	}
+
+	@Test
+	public void testLogStdout() {
+		// Unfortunately, `printf` can't easily be re-directed to Java, so this test only works manually.
+		InferenceParameters params = new InferenceParameters(prefix)
+				.setNPredict(nPredict)
+				.setSeed(42);
+
+		System.out.println("########## Log Text ##########");
+		LlamaModel.setLogger(LogFormat.TEXT, null);
+		model.complete(params);
+
+		System.out.println("########## Log JSON ##########");
+		LlamaModel.setLogger(LogFormat.JSON, null);
+		model.complete(params);
+
+		System.out.println("########## Log None ##########");
+		LlamaModel.setLogger(LogFormat.TEXT, (level, msg) -> {});
+		model.complete(params);
+
+		System.out.println("##############################");
+	}
+
+	private String completeAndReadStdOut() {
+		PrintStream stdOut = System.out;
+		ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+		@SuppressWarnings("ImplicitDefaultCharsetUsage") PrintStream printStream = new PrintStream(outputStream);
+		System.setOut(printStream);
+
+		try {
+			InferenceParameters params = new InferenceParameters(prefix)
+					.setNPredict(nPredict)
+					.setSeed(42);
+			model.complete(params);
+		} finally {
+			System.out.flush();
+			System.setOut(stdOut);
+			printStream.close();
+		}
+
+		return outputStream.toString();
+	}
+
+	private List<String> splitLines(String text) {
+		List<String> lines = new ArrayList<>();
+
+		Scanner scanner = new Scanner(text);
+		while (scanner.hasNextLine()) {
+			String line = scanner.nextLine();
+			lines.add(line);
+		}
+		scanner.close();
+
+		return lines;
+	}
+
+	private static final class LogMessage {
+		private final LogLevel level;
+		private final String text;
+
+		private LogMessage(LogLevel level, String text) {
+			this.level = level;
+			this.text = text;
+		}
+	}
 }

From 5e3c269b589c5605a3217d2ca1df71a70db6e3f5 Mon Sep 17 00:00:00 2001
From: Konstantin Herud
Date: Sat, 25 May 2024 11:32:06 +0200
Subject: [PATCH 3/5] re-add jllama.h

---
 src/main/cpp/jllama.h | 85 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)
 create mode 100644 src/main/cpp/jllama.h

diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h
new file mode 100644
index 00000000..2fd0529e
--- /dev/null
+++ b/src/main/cpp/jllama.h
@@ -0,0 +1,85 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class de_kherud_llama_LlamaModel */
+
+#ifndef _Included_de_kherud_llama_LlamaModel
+#define _Included_de_kherud_llama_LlamaModel
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    embed
+ * Signature: (Ljava/lang/String;)[F
+ */
+JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed
+  (JNIEnv *, jobject, jstring);
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    encode
+ * Signature: (Ljava/lang/String;)[I
+ */
+JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode
+  (JNIEnv *, jobject, jstring);
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ * Method:    setLogger
+ * Signature: (Lde/kherud/llama/args/LogFormat;Ljava/util/function/BiConsumer;)V
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger
+  (JNIEnv *, jclass, jobject, jobject);
+
+/*
+ * Class:     de_kherud_llama_LlamaModel
+ *
Method: requestCompletion + * Signature: (Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion + (JNIEnv *, jobject, jstring); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: receiveCompletion + * Signature: (I)Lde/kherud/llama/LlamaOutput; + */ +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion + (JNIEnv *, jobject, jint); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: cancelCompletion + * Signature: (I)V + */ +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion + (JNIEnv *, jobject, jint); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: decodeBytes + * Signature: ([I)[B + */ +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes + (JNIEnv *, jobject, jintArray); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: loadModel + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel + (JNIEnv *, jobject, jstring); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: delete + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete + (JNIEnv *, jobject); + +#ifdef __cplusplus +} +#endif +#endif From be41f901ace337001570c565bc9517d958906d2f Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 25 May 2024 11:32:31 +0200 Subject: [PATCH 4/5] format c++ code --- src/main/cpp/jllama.cpp | 28 +++++++++++++++++----------- src/main/cpp/server.hpp | 3 ++- src/main/cpp/utils.hpp | 37 +++++++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 4c087bf4..4cf62c33 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -13,7 +13,7 @@ namespace { -JavaVM* g_vm = nullptr; +JavaVM *g_vm = nullptr; // classes jclass c_llama_model = nullptr; @@ -117,7 +117,8 @@ jobject log_level_to_jobject(ggml_log_level level) return o_log_level_error; case GGML_LOG_LEVEL_WARN: return o_log_level_warn; - default: case GGML_LOG_LEVEL_INFO: + default: + case GGML_LOG_LEVEL_INFO: return o_log_level_info; case GGML_LOG_LEVEL_DEBUG: return o_log_level_debug; @@ -127,9 +128,11 @@ jobject log_level_to_jobject(ggml_log_level level) /** * Returns the JNIEnv of the current thread. 
 */
-JNIEnv* get_jni_env() {
-    JNIEnv* env = nullptr;
-    if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
+JNIEnv *get_jni_env()
+{
+    JNIEnv *env = nullptr;
+    if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK)
+    {
         throw std::runtime_error("Thread is not attached to the JVM");
     }
     return env;
@@ -436,10 +439,12 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
 
     std::thread t([ctx_server]() {
         JNIEnv *env;
-        jint res = g_vm->GetEnv((void**)&env, JNI_VERSION_1_6);
-        if (res == JNI_EDETACHED) {
-            res = g_vm->AttachCurrentThread((void**)&env, nullptr);
-            if (res != JNI_OK) {
+        jint res = g_vm->GetEnv((void **)&env, JNI_VERSION_1_6);
+        if (res == JNI_EDETACHED)
+        {
+            res = g_vm->AttachCurrentThread((void **)&env, nullptr);
+            if (res != JNI_OK)
+            {
                 throw std::runtime_error("Failed to attach thread to JVM");
             }
         }
@@ -459,7 +464,8 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
     json json_params = json::parse(c_params);
     const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix");
 
-    if (json_params.value("use_chat_template", false)) {
+    if (json_params.value("use_chat_template", false))
+    {
         json chat;
         chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}});
         chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}});
@@ -631,7 +637,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jc
     {
         o_log_callback = env->NewGlobalRef(jcallback);
         log_callback = [](enum ggml_log_level level, const char *text, void *user_data) {
-            JNIEnv* env = get_jni_env();
+            JNIEnv *env = get_jni_env();
             jstring message = env->NewStringUTF(text);
             jobject log_level = log_level_to_jobject(level);
             env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message);
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index b111bb7d..d3d4750a 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -2139,7 +2139,8 @@ struct server_context
             slot.command = SLOT_COMMAND_NONE;
             slot.release();
             slot.print_timings();
-            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
+            send_error(slot, "input is too large to process. increase the physical batch size",
increase the physical batch size", + ERROR_TYPE_SERVER); continue; } } diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp index 56f6742a..ad7198c1 100644 --- a/src/main/cpp/utils.hpp +++ b/src/main/cpp/utils.hpp @@ -70,20 +70,24 @@ template static T json_value(const json &body, const std::string &k } } -static const char * log_level_to_string(ggml_log_level level) { - switch (level) { - case GGML_LOG_LEVEL_ERROR: - return "ERROR"; - case GGML_LOG_LEVEL_WARN: - return "WARN"; - default: case GGML_LOG_LEVEL_INFO: - return "INFO"; - case GGML_LOG_LEVEL_DEBUG: - return "DEBUG"; +static const char *log_level_to_string(ggml_log_level level) +{ + switch (level) + { + case GGML_LOG_LEVEL_ERROR: + return "ERROR"; + case GGML_LOG_LEVEL_WARN: + return "WARN"; + default: + case GGML_LOG_LEVEL_INFO: + return "INFO"; + case GGML_LOG_LEVEL_DEBUG: + return "DEBUG"; } } -static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, const json &extra) +static inline void server_log(ggml_log_level level, const char *function, int line, const char *message, + const json &extra) { std::stringstream ss_tid; ss_tid << std::this_thread::get_id(); @@ -110,7 +114,9 @@ static inline void server_log(ggml_log_level level, const char *function, int li if (log_callback == nullptr) { printf("%s\n", dump.c_str()); - } else { + } + else + { log_callback(level, dump.c_str(), nullptr); } } @@ -135,9 +141,12 @@ static inline void server_log(ggml_log_level level, const char *function, int li #endif const std::string str = ss.str(); - if (log_callback == nullptr) { + if (log_callback == nullptr) + { printf("[%4s] %.*s\n", log_level_to_string(level), (int)str.size(), str.data()); - } else { + } + else + { log_callback(level, str.c_str(), nullptr); } } From 7c97036a31a28b3b3e445f9d3dea8ba1994de7b0 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 25 May 2024 11:47:14 +0200 Subject: [PATCH 5/5] Update readme logging --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index 6ea1df8e..09e9dfef 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,23 @@ try (LlamaModel model = new LlamaModel(modelParams)) { } ``` +### Logging + +Per default, logs are written to stdout. +This can be intercepted via the static method `LlamaModel.setLogger(LogFormat, BiConsumer)`. +There is text- and JSON-based logging. The default is JSON. +To only change the log format while still writing to stdout, `null` can be passed for the callback. +Logging can be disabled by passing an empty callback. + +```java +// Re-direct log messages however you like (e.g. to a logging library) +LlamaModel.setLogger(LogFormat.TEXT, (level, message) -> System.out.println(level.name() + ": " + message)); +// Log to stdout, but change the format +LlamaModel.setLogger(LogFormat.TEXT, null); +// Disable logging by passing a no-op +LlamaModel.setLogger(null, (level, message) -> {}); +``` + ## Importing in Android You can use this library in Android project.