Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Implement Logging #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ project(jllama CXX)
include(FetchContent)

set(BUILD_SHARED_LIBS ON)
set(LLAMA_STATIC OFF)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(LLAMA_VERBOSE "llama: verbose output" OFF)

Expand Down
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,23 @@ try (LlamaModel model = new LlamaModel(modelParams)) {
}
```

### Logging

By default, logs are written to stdout.
This can be intercepted via the static method `LlamaModel.setLogger(LogFormat, BiConsumer<LogLevel, String>)`.
There is text- and JSON-based logging. The default is JSON.
To only change the log format while still writing to stdout, `null` can be passed for the callback.
Logging can be disabled by passing an empty callback.

```java
// Re-direct log messages however you like (e.g. to a logging library)
LlamaModel.setLogger(LogFormat.TEXT, (level, message) -> System.out.println(level.name() + ": " + message));
// Log to stdout, but change the format
LlamaModel.setLogger(LogFormat.TEXT, null);
// Disable logging by passing a no-op
LlamaModel.setLogger(null, (level, message) -> {});
```

## Importing in Android

You can use this library in an Android project.
Expand Down
161 changes: 146 additions & 15 deletions src/main/cpp/jllama.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#include "jllama.h"

#include "nlohmann/json.hpp"
#include "llama.h"
#include "nlohmann/json.hpp"
#include "server.hpp"

#include <functional>
#include <stdexcept>

// We store some references to Java classes and their fields/methods here to speed up things for later and to fail
Expand All @@ -12,7 +13,7 @@

namespace
{
// JavaVM *g_vm = nullptr;
JavaVM *g_vm = nullptr;

// classes
jclass c_llama_model = nullptr;
Expand All @@ -29,6 +30,8 @@ jclass c_integer = nullptr;
jclass c_float = nullptr;
jclass c_biconsumer = nullptr;
jclass c_llama_error = nullptr;
jclass c_log_level = nullptr;
jclass c_log_format = nullptr;
jclass c_error_oom = nullptr;

// constructors
Expand All @@ -55,9 +58,22 @@ jfieldID f_model_pointer = nullptr;
jfieldID f_task_id = nullptr;
jfieldID f_utf_8 = nullptr;
jfieldID f_iter_has_next = nullptr;
jfieldID f_log_level_debug = nullptr;
jfieldID f_log_level_info = nullptr;
jfieldID f_log_level_warn = nullptr;
jfieldID f_log_level_error = nullptr;
jfieldID f_log_format_json = nullptr;
jfieldID f_log_format_text = nullptr;

// objects
jobject o_utf_8 = nullptr;
jobject o_log_level_debug = nullptr;
jobject o_log_level_info = nullptr;
jobject o_log_level_warn = nullptr;
jobject o_log_level_error = nullptr;
jobject o_log_format_json = nullptr;
jobject o_log_format_text = nullptr;
jobject o_log_callback = nullptr;

/**
* Convert a Java string to a std::string
Expand Down Expand Up @@ -89,8 +105,43 @@ jbyteArray parse_jbytes(JNIEnv *env, const std::string &string)
env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast<const jbyte *>(string.c_str()));
return bytes;
}

/**
 * Map a llama.cpp log level to its Java enumeration option.
 * Any unrecognized level is reported as INFO.
 */
jobject log_level_to_jobject(ggml_log_level level)
{
    if (level == GGML_LOG_LEVEL_ERROR)
    {
        return o_log_level_error;
    }
    if (level == GGML_LOG_LEVEL_WARN)
    {
        return o_log_level_warn;
    }
    if (level == GGML_LOG_LEVEL_DEBUG)
    {
        return o_log_level_debug;
    }
    // GGML_LOG_LEVEL_INFO and every other value
    return o_log_level_info;
}

/**
 * Returns the JNIEnv of the current thread.
 *
 * @throws std::runtime_error if no JVM was registered or the calling thread is not attached to it
 */
JNIEnv *get_jni_env()
{
    JNIEnv *env = nullptr;
    const bool attached =
        g_vm != nullptr && g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) == JNI_OK;
    if (!attached)
    {
        throw std::runtime_error("Thread is not attached to the JVM");
    }
    return env;
}
} // namespace

bool log_json;
std::function<void(ggml_log_level, const char *, void *)> log_callback;

/**
* The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`).
* `JNI_OnLoad` must return the JNI version needed by the native library.
Expand All @@ -101,6 +152,7 @@ jbyteArray parse_jbytes(JNIEnv *env, const std::string &string)
*/
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
{
g_vm = vm;
JNIEnv *env = nullptr;

if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1))
Expand All @@ -123,10 +175,13 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
c_float = env->FindClass("java/lang/Float");
c_biconsumer = env->FindClass("java/util/function/BiConsumer");
c_llama_error = env->FindClass("de/kherud/llama/LlamaException");
c_log_level = env->FindClass("de/kherud/llama/LogLevel");
c_log_format = env->FindClass("de/kherud/llama/args/LogFormat");
c_error_oom = env->FindClass("java/lang/OutOfMemoryError");

if (!(c_llama_model && c_llama_iterator && c_standard_charsets && c_output && c_string && c_hash_map && c_map &&
c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_error_oom))
c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level &&
c_log_format && c_error_oom))
{
goto error;
}
Expand All @@ -145,6 +200,8 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
c_float = (jclass)env->NewGlobalRef(c_float);
c_biconsumer = (jclass)env->NewGlobalRef(c_biconsumer);
c_llama_error = (jclass)env->NewGlobalRef(c_llama_error);
c_log_level = (jclass)env->NewGlobalRef(c_log_level);
c_log_format = (jclass)env->NewGlobalRef(c_log_format);
c_error_oom = (jclass)env->NewGlobalRef(c_error_oom);

// find constructors
Expand Down Expand Up @@ -182,34 +239,56 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
f_task_id = env->GetFieldID(c_llama_iterator, "taskId", "I");
f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;");
f_iter_has_next = env->GetFieldID(c_llama_iterator, "hasNext", "Z");

if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next))
f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;");
f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;");
f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;");
f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;");
f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lde/kherud/llama/args/LogFormat;");
f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;");

if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next && f_log_level_debug && f_log_level_info &&
f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text))
{
goto error;
}

o_utf_8 = env->NewStringUTF("UTF-8");

if (!(o_utf_8))
o_log_level_debug = env->GetStaticObjectField(c_log_level, f_log_level_debug);
o_log_level_info = env->GetStaticObjectField(c_log_level, f_log_level_info);
o_log_level_warn = env->GetStaticObjectField(c_log_level, f_log_level_warn);
o_log_level_error = env->GetStaticObjectField(c_log_level, f_log_level_error);
o_log_format_json = env->GetStaticObjectField(c_log_format, f_log_format_json);
o_log_format_text = env->GetStaticObjectField(c_log_format, f_log_format_text);

if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error &&
o_log_format_json && o_log_format_text))
{
goto error;
}

o_utf_8 = (jclass)env->NewGlobalRef(o_utf_8);
o_utf_8 = env->NewGlobalRef(o_utf_8);
o_log_level_debug = env->NewGlobalRef(o_log_level_debug);
o_log_level_info = env->NewGlobalRef(o_log_level_info);
o_log_level_warn = env->NewGlobalRef(o_log_level_warn);
o_log_level_error = env->NewGlobalRef(o_log_level_error);
o_log_format_json = env->NewGlobalRef(o_log_format_json);
o_log_format_text = env->NewGlobalRef(o_log_format_text);

if (env->ExceptionCheck())
{
env->ExceptionDescribe();
goto error;
}

llama_backend_init();

goto success;

error:
return JNI_ERR;

success:
return JNI_VERSION_1_2;
return JNI_VERSION_1_6;
}

/**
Expand All @@ -224,7 +303,7 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
{
JNIEnv *env = nullptr;

if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1))
if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6))
{
return;
}
Expand All @@ -242,9 +321,24 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
env->DeleteGlobalRef(c_float);
env->DeleteGlobalRef(c_biconsumer);
env->DeleteGlobalRef(c_llama_error);
env->DeleteGlobalRef(c_log_level);
env->DeleteGlobalRef(c_log_level);
env->DeleteGlobalRef(c_error_oom);

env->DeleteGlobalRef(o_utf_8);
env->DeleteGlobalRef(o_log_level_debug);
env->DeleteGlobalRef(o_log_level_info);
env->DeleteGlobalRef(o_log_level_warn);
env->DeleteGlobalRef(o_log_level_error);
env->DeleteGlobalRef(o_log_format_json);
env->DeleteGlobalRef(o_log_format_text);

if (o_log_callback != nullptr)
{
env->DeleteGlobalRef(o_log_callback);
}

llama_backend_free();
}

JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams)
Expand Down Expand Up @@ -277,7 +371,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
params.model_alias = params.model;
}

llama_backend_init();
llama_numa_init(params.numa);

LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
Expand Down Expand Up @@ -344,7 +437,19 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3));

std::thread t([ctx_server]() { ctx_server->queue_tasks.start_loop(); });
std::thread t([ctx_server]() {
JNIEnv *env;
jint res = g_vm->GetEnv((void **)&env, JNI_VERSION_1_6);
if (res == JNI_EDETACHED)
{
res = g_vm->AttachCurrentThread((void **)&env, nullptr);
if (res != JNI_OK)
{
throw std::runtime_error("Failed to attach thread to JVM");
}
}
ctx_server->queue_tasks.start_loop();
});
t.detach();

env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
Expand All @@ -359,7 +464,8 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
json json_params = json::parse(c_params);
const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix");

if (json_params.value("use_chat_template", false)) {
if (json_params.value("use_chat_template", false))
{
json chat;
chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}});
chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}});
Expand Down Expand Up @@ -502,8 +608,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobje
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
ctx_server->queue_tasks.terminate();
// maybe we should keep track how many models were loaded before freeing the backend
llama_backend_free();
delete ctx_server;
}

Expand All @@ -514,3 +618,30 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *
ctx_server->request_cancel(id_task);
ctx_server->queue_results.remove_waiting_task_id(id_task);
}

/**
 * Configure how llama.cpp log messages are emitted.
 *
 * @param log_format the JSON or TEXT option of the Java `LogFormat` enum (controls `log_json`)
 * @param jcallback  a Java `BiConsumer<LogLevel, String>` to forward messages to, or `null` to
 *                   keep logging to stdout
 */
JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format,
                                                                 jobject jcallback)
{
    // Release any previously installed callback and reset the pointer, so a stale global
    // reference is never passed to DeleteGlobalRef twice (e.g. setLogger(fmt, null) followed
    // by another setLogger call would otherwise delete an already-deleted reference).
    if (o_log_callback != nullptr)
    {
        env->DeleteGlobalRef(o_log_callback);
        o_log_callback = nullptr;
    }

    log_json = env->IsSameObject(log_format, o_log_format_json);

    if (jcallback == nullptr)
    {
        log_callback = nullptr;
    }
    else
    {
        o_log_callback = env->NewGlobalRef(jcallback);
        log_callback = [](enum ggml_log_level level, const char *text, void *user_data) {
            JNIEnv *env = get_jni_env();
            jstring message = env->NewStringUTF(text);
            jobject log_level = log_level_to_jobject(level);
            env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message);
            // A throwing Java callback must not leave a pending exception on this thread;
            // logging may happen on native threads where it would poison later JNI calls.
            if (env->ExceptionCheck())
            {
                env->ExceptionDescribe();
            }
            env->DeleteLocalRef(message);
        };
    }
}
8 changes: 8 additions & 0 deletions src/main/cpp/jllama.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions src/main/cpp/server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

using json = nlohmann::ordered_json;

bool server_log_json = true;

enum stop_type
{
STOP_TYPE_FULL,
Expand Down Expand Up @@ -2141,7 +2139,8 @@ struct server_context
slot.command = SLOT_COMMAND_NONE;
slot.release();
slot.print_timings();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
send_error(slot, "input is too large to process. increase the physical batch size",
ERROR_TYPE_SERVER);
continue;
}
}
Expand Down Expand Up @@ -2579,7 +2578,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
params.input_prefix = json_value(jparams, "input_prefix", default_params.input_prefix);
params.input_suffix = json_value(jparams, "input_suffix", default_params.input_suffix);
params.antiprompt = json_value(jparams, "antiprompt", default_params.antiprompt);
params.logdir = json_value(jparams, "logdir", default_params.logdir);
params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static);
params.lookup_cache_dynamic = json_value(jparams, "lookup_cache_dynamic", default_params.lookup_cache_dynamic);
params.logits_file = json_value(jparams, "logits_file", default_params.logits_file);
Expand All @@ -2594,7 +2592,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock);
params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload);
server_log_json = !jparams.contains("log_format") || jparams["log_format"] == "json";
sparams.system_prompt = json_value(jparams, "system_prompt", default_sparams.system_prompt);
sparams.chat_template = json_value(jparams, "chat_template", default_sparams.chat_template);

Expand Down
Loading