diff --git a/CMakeLists.txt b/CMakeLists.txt
index c1873c20..395f37ee 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,7 +22,7 @@ FetchContent_MakeAvailable(json)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
- GIT_TAG b2702
+ GIT_TAG b2797
)
FetchContent_MakeAvailable(llama.cpp)
diff --git a/README.md b/README.md
index 7fbc6e44..afedb0fc 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@

-
+
# Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp)
@@ -18,7 +18,7 @@ This repository provides Java bindings for the C++ library.
3. [Android](#importing-in-android)
> [!NOTE]
-> Now with Llama 3 support
+> Now with support for Llama 3, Phi-3, and flash attention
## Quick Start
@@ -28,7 +28,7 @@ Access this library via Maven:
<groupId>de.kherud</groupId>
<artifactId>llama</artifactId>
- <version>3.0.1</version>
+ <version>3.0.2</version>
```
diff --git a/pom.xml b/pom.xml
index 66b9eb6c..c111bb7c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
<groupId>de.kherud</groupId>
<artifactId>llama</artifactId>
- <version>3.0.1</version>
+ <version>3.0.2</version>
<packaging>jar</packaging>
<name>${project.groupId}:${project.artifactId}</name>
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 8295f42a..4c58e548 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -910,7 +910,7 @@ struct server_context
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.params.seed = json_value(data, "seed", default_params.seed);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@@ -1209,7 +1209,7 @@ struct server_context
bool process_token(completion_token_output &result, server_slot &slot)
{
// remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = llama_token_to_piece(ctx, result.tok);
+ const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
slot.sampled = result.tok;
// search stop word and delete it
@@ -1314,6 +1314,27 @@ struct server_context
LOG_VERBOSE("eos token found", {});
}
+ auto n_ctx_train = llama_n_ctx_train(model);
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
+ && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+ LOG_WARNING("n_predict is not set and self-context extend is disabled."
+ " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
+ { "id_slot", slot.id },
+ { "params.n_predict", slot.params.n_predict },
+ { "slot.n_prompt_tokens", slot.n_prompt_tokens },
+ { "slot.n_decoded", slot.n_decoded },
+ { "slot.n_predict", slot.n_predict },
+ { "n_slots", params.n_parallel },
+ { "slot.n_ctx", slot.n_ctx },
+ { "n_ctx", n_ctx },
+ { "n_ctx_train", n_ctx_train },
+ { "ga_n", slot.ga_n },
+ });
+ slot.truncated = true;
+ slot.stopped_limit = true;
+ slot.has_next_token = false; // stop prediction
+ }
+
LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@@ -1475,8 +1496,9 @@ struct server_context
{
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
+ size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end() - stop_word_toks.size());
+ slot.generated_token_probs.end() - safe_offset);
}
else
{
@@ -2313,7 +2335,7 @@ struct server_context
});
// process the created batch of tokens
- for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch)
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2534,6 +2556,7 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
params.embedding = json_value(jparams, "embedding", default_params.embedding);
params.escape = json_value(jparams, "escape", default_params.escape);
params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
+ params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn);
params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos);
params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos);
params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
@@ -2596,4 +2619,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {});
#endif
}
+
+ gpt_params_handle_model_default(params);
}
diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java
index da38d409..8257dc22 100644
--- a/src/main/java/de/kherud/llama/ModelParameters.java
+++ b/src/main/java/de/kherud/llama/ModelParameters.java
@@ -61,6 +61,7 @@ public final class ModelParameters extends JsonParameters {
private static final String PARAM_LORA_BASE = "lora_base";
private static final String PARAM_EMBEDDING = "embedding";
private static final String PARAM_CONT_BATCHING = "cont_batching";
+ private static final String PARAM_FLASH_ATTENTION = "flash_attn";
private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos";
private static final String PARAM_IGNORE_EOS = "ignore_eos";
private static final String PARAM_USE_MMAP = "use_mmap";
@@ -526,6 +527,14 @@ public ModelParameters setContinuousBatching(boolean contBatching) {
return this;
}
+ /**
+ * Whether to enable Flash Attention (default: disabled)
+ */
+ public ModelParameters setFlashAttention(boolean flashAttention) {
+ parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention));
+ return this;
+ }
+
/**
* Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string
*/
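
For reference, a minimal sketch of how the new flash attention flag would be enabled from the Java side. Only `setFlashAttention` and `setContinuousBatching` are confirmed by the diff above; the `LlamaModel` constructor mentioned in the comment is an assumption about the surrounding library, not part of this change.

```java
import de.kherud.llama.ModelParameters;

public class FlashAttentionExample {
    public static void main(String[] args) {
        // Build the JSON-backed parameter map. setFlashAttention(true) stores the
        // "flash_attn" entry, which server_params_parse (src/main/cpp/server.hpp)
        // now forwards to params.flash_attn on the native side.
        ModelParameters params = new ModelParameters()
                .setContinuousBatching(true)  // existing setter, shown for context
                .setFlashAttention(true);     // new setter added in this change

        // These parameters are then passed to the model, e.g. new LlamaModel(params);
        // the constructor name is assumed here and not verified by this diff.
    }
}
```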