Upgrade to llama.cpp b2797 #60

Merged 1 commit on May 6, 2024.

CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ FetchContent_MakeAvailable(json)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
- GIT_TAG b2702
+ GIT_TAG b2797
)
FetchContent_MakeAvailable(llama.cpp)

README.md (6 changes: 3 additions & 3 deletions)
@@ -1,5 +1,5 @@
![Java 11+](https://img.shields.io/badge/Java-11%2B-informational)
- ![llama.cpp b2702](https://img.shields.io/badge/llama.cpp-%23b2702-informational)
+ ![llama.cpp b2797](https://img.shields.io/badge/llama.cpp-%23b2797-informational)

# Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp)

@@ -18,7 +18,7 @@ This repository provides Java bindings for the C++ library.
3. [Android](#importing-in-android)

> [!NOTE]
- > Now with Llama 3 support
+ > Now with support for Llama 3, Phi-3, and flash attention

## Quick Start

@@ -28,7 +28,7 @@ Access this library via Maven:
<dependency>
<groupId>de.kherud</groupId>
<artifactId>llama</artifactId>
- <version>3.0.1</version>
+ <version>3.0.2</version>
</dependency>
```

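As an editorial aside (not part of this PR's diff), here is a minimal quick-start sketch against the 3.0.2 release referenced above. Only `ModelParameters`, `setContinuousBatching(boolean)`, and `setFlashAttention(boolean)` appear in this changeset; `LlamaModel` and `setModelFilePath` are assumed names for the model class and model-path setter, so check the project README for the exact API.

```java
import de.kherud.llama.LlamaModel;
import de.kherud.llama.ModelParameters;

public class QuickStart {
    public static void main(String[] args) {
        // setFlashAttention(true) maps to the new "flash_attn" option added by this upgrade.
        ModelParameters params = new ModelParameters()
                .setModelFilePath("models/llama-3-8b-instruct.Q4_K_M.gguf") // assumed setter name
                .setContinuousBatching(true)
                .setFlashAttention(true);

        // LlamaModel is assumed here to load the GGUF file and to free native memory on close().
        try (LlamaModel model = new LlamaModel(params)) {
            // run completions or embeddings as documented in the README
        }
    }
}
```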

pom.xml (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@

<groupId>de.kherud</groupId>
<artifactId>llama</artifactId>
- <version>3.0.1</version>
+ <version>3.0.2</version>
<packaging>jar</packaging>

<name>${project.groupId}:${project.artifactId}</name>

src/main/cpp/server.hpp (33 changes: 29 additions & 4 deletions)
@@ -910,7 +910,7 @@ struct server_context
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.params.seed = json_value(data, "seed", default_params.seed);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@@ -1209,7 +1209,7 @@ struct server_context
bool process_token(completion_token_output &result, server_slot &slot)
{
// remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = llama_token_to_piece(ctx, result.tok);
+ const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
slot.sampled = result.tok;

// search stop word and delete it
@@ -1314,6 +1314,27 @@ struct server_context
LOG_VERBOSE("eos token found", {});
}

+ auto n_ctx_train = llama_n_ctx_train(model);
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
+ && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+ LOG_WARNING("n_predict is not set and self-context extend is disabled."
+ " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
+ { "id_slot", slot.id },
+ { "params.n_predict", slot.params.n_predict },
+ { "slot.n_prompt_tokens", slot.n_prompt_tokens },
+ { "slot.n_decoded", slot.n_decoded },
+ { "slot.n_predict", slot.n_predict },
+ { "n_slots", params.n_parallel },
+ { "slot.n_ctx", slot.n_ctx },
+ { "n_ctx", n_ctx },
+ { "n_ctx_train", n_ctx_train },
+ { "ga_n", slot.ga_n },
+ });
+ slot.truncated = true;
+ slot.stopped_limit = true;
+ slot.has_next_token = false; // stop prediction
+ }

LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@@ -1475,8 +1496,9 @@ struct server_context
{
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

+ size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end() - stop_word_toks.size());
+ slot.generated_token_probs.end() - safe_offset);
}
else
{
@@ -2313,7 +2335,7 @@ struct server_context
});

// process the created batch of tokens
- for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch)
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

@@ -2534,6 +2556,7 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
params.embedding = json_value(jparams, "embedding", default_params.embedding);
params.escape = json_value(jparams, "escape", default_params.escape);
params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
+ params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn);
params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos);
params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos);
params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
@@ -2596,4 +2619,6 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {});
#endif
}

+ gpt_params_handle_model_default(params);
}
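
Two of the server.hpp changes above affect per-request behavior: the "seed" value is now stored in `slot.sparams.seed` (the sampling parameters) rather than `slot.params.seed`, and generation without an explicit `n_predict` is now capped at `n_ctx_train` so a model that never emits EOS cannot loop forever. As a hedged illustration of the caller's side, the sketch below sets both fields per request; `InferenceParameters`, `setNPredict`, and `setSeed` are assumed names mapping to the "n_predict" and "seed" JSON fields the server reads.

```java
import de.kherud.llama.InferenceParameters; // assumed class name for per-request parameters

public class BoundedGeneration {
    public static void main(String[] args) {
        // An explicit n_predict bounds generation, so the new n_ctx_train safety cap never
        // has to trigger; the seed now reliably reaches the sampling parameters natively.
        InferenceParameters inferParams = new InferenceParameters("Tell me a joke.") // assumed constructor
                .setNPredict(256) // assumed setter for the "n_predict" field
                .setSeed(42);     // assumed setter for the "seed" field
        // pass inferParams to a LlamaModel generate/complete call as shown in the README
    }
}
```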

src/main/java/de/kherud/llama/ModelParameters.java (9 changes: 9 additions & 0 deletions)
@@ -61,6 +61,7 @@ public final class ModelParameters extends JsonParameters {
private static final String PARAM_LORA_BASE = "lora_base";
private static final String PARAM_EMBEDDING = "embedding";
private static final String PARAM_CONT_BATCHING = "cont_batching";
+ private static final String PARAM_FLASH_ATTENTION = "flash_attn";
private static final String PARAM_INPUT_PREFIX_BOS = "input_prefix_bos";
private static final String PARAM_IGNORE_EOS = "ignore_eos";
private static final String PARAM_USE_MMAP = "use_mmap";
@@ -526,6 +527,14 @@ public ModelParameters setContinuousBatching(boolean contBatching) {
return this;
}

+ /**
+  * Whether to enable Flash Attention (default: disabled)
+  */
+ public ModelParameters setFlashAttention(boolean flashAttention) {
+     parameters.put(PARAM_FLASH_ATTENTION, String.valueOf(flashAttention));
+     return this;
+ }

/**
* Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string
*/
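
Editorial note: the new setter and the native option connect through the JSON key. `setFlashAttention(true)` stores its value under `PARAM_FLASH_ATTENTION` ("flash_attn"), which `server_params_parse` in `server.hpp` reads into `params.flash_attn`. A minimal sketch of the chained builder usage:

```java
import de.kherud.llama.ModelParameters;

public class FlashAttentionToggle {
    public static void main(String[] args) {
        // Both setters are builder-style and return `this`, so they chain.
        ModelParameters params = new ModelParameters()
                .setContinuousBatching(true) // existing option, JSON key "cont_batching"
                .setFlashAttention(true);    // option added in this PR, JSON key "flash_attn"

        // ModelParameters extends JsonParameters, so these options end up in the JSON
        // payload that server_params_parse consumes on the native side.
        System.out.println(params); // assumes the JSON form is printable; illustrative only
    }
}
```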