diff --git a/CMakeLists.txt b/CMakeLists.txt index c1873c20..395f37ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ FetchContent_MakeAvailable(json) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2702 + GIT_TAG b2797 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/README.md b/README.md index 7fbc6e44..afedb0fc 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ ![Java 11+](https://img.shields.io/badge/Java-11%2B-informational) -![llama.cpp b2702](https://img.shields.io/badge/llama.cpp-%23b2702-informational) +![llama.cpp b2797](https://img.shields.io/badge/llama.cpp-%23b2797-informational) # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) @@ -18,7 +18,7 @@ This repository provides Java bindings for the C++ library. 3. [Android](#importing-in-android) > [!NOTE] -> Now with Llama 3 support +> Now with support for Llama 3, Phi-3, and flash attention ## Quick Start @@ -28,7 +28,7 @@ Access this library via Maven: de.kherud llama - 3.0.1 + 3.0.2 ``` diff --git a/pom.xml b/pom.xml index 66b9eb6c..c111bb7c 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ de.kherud llama - 3.0.1 + 3.0.2 jar ${project.groupId}:${project.artifactId} diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 8295f42a..4c58e548 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -910,7 +910,7 @@ struct server_context slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.params.seed = json_value(data, "seed", default_params.seed); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); @@ -1209,7 +1209,7 @@ struct server_context bool process_token(completion_token_output &result, server_slot &slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok); + const std::string token_str = llama_token_to_piece(ctx, result.tok, false); slot.sampled = result.tok; // search stop word and delete it @@ -1314,6 +1314,27 @@ struct server_context LOG_VERBOSE("eos token found", {}); } + auto n_ctx_train = llama_n_ctx_train(model); + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 + && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + LOG_WARNING("n_predict is not set and self-context extend is disabled." + " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { + { "id_slot", slot.id }, + { "params.n_predict", slot.params.n_predict }, + { "slot.n_prompt_tokens", slot.n_prompt_tokens }, + { "slot.n_decoded", slot.n_decoded }, + { "slot.n_predict", slot.n_predict }, + { "n_slots", params.n_parallel }, + { "slot.n_ctx", slot.n_ctx }, + { "n_ctx", n_ctx }, + { "n_ctx_train", n_ctx_train }, + { "ga_n", slot.ga_n }, + }); + slot.truncated = true; + slot.stopped_limit = true; + slot.has_next_token = false; // stop prediction + } + LOG_VERBOSE("next token", { {"id_slot", slot.id}, {"id_task", slot.id_task}, @@ -1475,8 +1496,9 @@ struct server_context { const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); probs = std::vector(slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - stop_word_toks.size()); + slot.generated_token_probs.end() - safe_offset); } else { @@ -2313,7 +2335,7 @@ struct server_context }); // process the created batch of tokens - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); @@ -2534,6 +2556,7 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params params.embedding = json_value(jparams, "embedding", default_params.embedding); params.escape = json_value(jparams, "escape", default_params.escape); params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching); + params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn); params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos); params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos); params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap); @@ -2596,4 +2619,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {}); #endif } + + gpt_params_handle_model_default(params); } diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index da38d409..8257dc22 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -61,6 +61,7 @@ public final class ModelParameters extends JsonParameters { private static final String PARAM_LORA_BASE = "lora_base"; private static final String PARAM_EMBEDDING = "embedding"; private static final String PARAM_CONT_BATCHING = "cont_batching"; + private static final String PARAM_FLASH_ATTENTION = "flash_attn"; private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos"; private static final String PARAM_IGNORE_EOS = "ignore_eos"; private static final String PARAM_USE_MMAP = "use_mmap"; @@ -526,6 +527,14 @@ public ModelParameters setContinuousBatching(boolean contBatching) { return this; } + /** + * Whether to enable Flash Attention (default: disabled) + */ + public ModelParameters setFlashAttention(boolean flashAttention) { + parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention)); + return this; + } + /** * Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string */