diff --git a/.clang-format b/.clang-format
index 1d24348d..a113c01c 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,6 +1,6 @@
---
Language: Cpp
-# BasedOnStyle: Microsoft
+# BasedOnStyle: LLVM
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignArrayOfStructures: None
@@ -28,11 +28,6 @@ AlignConsecutiveMacros:
AcrossComments: false
AlignCompound: false
PadOperators: false
-AlignConsecutiveShortCaseStatements:
- Enabled: false
- AcrossEmptyLines: false
- AcrossComments: false
- AlignCaseColons: false
AlignEscapedNewlines: Right
AlignOperands: Align
AlignTrailingComments:
@@ -42,8 +37,8 @@ AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
-AllowShortEnumsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: None
+AllowShortEnumsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
@@ -58,17 +53,17 @@ BinPackParameters: true
BitFieldColonSpacing: Both
BraceWrapping:
AfterCaseLabel: false
- AfterClass: true
- AfterControlStatement: Always
- AfterEnum: true
- AfterExternBlock: true
- AfterFunction: true
- AfterNamespace: true
- AfterObjCDeclaration: true
- AfterStruct: true
+ AfterClass: false
+ AfterControlStatement: Never
+ AfterEnum: false
+ AfterExternBlock: false
+ AfterFunction: false
+ AfterNamespace: false
+ AfterObjCDeclaration: false
+ AfterStruct: false
AfterUnion: false
- BeforeCatch: true
- BeforeElse: true
+ BeforeCatch: false
+ BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
@@ -80,7 +75,7 @@ BreakAfterJavaFieldAnnotations: false
BreakArrays: true
BreakBeforeBinaryOperators: None
BreakBeforeConceptDeclarations: Always
-BreakBeforeBraces: Custom
+BreakBeforeBraces: Attach
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
@@ -142,7 +137,6 @@ IntegerLiteralSeparator:
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: true
-KeepEmptyLinesAtEOF: false
LambdaBodyIndentation: Signature
LineEnding: DeriveLF
MacroBlockBegin: ''
@@ -150,7 +144,7 @@ MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
-ObjCBlockIndentWidth: 2
+ObjCBlockIndentWidth: 4
ObjCBreakBeforeNestedBlockParam: true
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
@@ -164,14 +158,13 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyIndentedWhitespace: 0
-PenaltyReturnTypeOnItsOwnLine: 1000
+PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Right
PPIndentWidth: -1
QualifierAlignment: Leave
ReferenceAlignment: Pointer
ReflowComments: true
RemoveBracesLLVM: false
-RemoveParentheses: Leave
RemoveSemicolon: false
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
@@ -189,7 +182,6 @@ SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
-SpaceBeforeJsonColon: false
SpaceBeforeParens: ControlStatements
SpaceBeforeParensOptions:
AfterControlStatements: true
@@ -204,18 +196,16 @@ SpaceBeforeParensOptions:
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
+SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: Never
+SpacesInConditionalStatement: false
SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: -1
-SpacesInParens: Never
-SpacesInParensOptions:
- InCStyleCasts: false
- InConditionalStatements: false
- InEmptyParentheses: false
- Other: false
+SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Latest
StatementAttributeLikeMacros:
@@ -223,9 +213,8 @@ StatementAttributeLikeMacros:
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
-TabWidth: 4
+TabWidth: 8
UseTab: Never
-VerilogBreakBetweenInstancePorts: true
WhitespaceSensitiveMacros:
- BOOST_PP_STRINGIZE
- CF_SWIFT_NAME
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1db8b696..a15f809d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,14 +1,15 @@
-# This work flow runs all Java tests for continuous integration.
-# Since it has to build llama.cpp first, for speed, it only runs / tests on the natively supported GitHub runners.
-
+---
name: Continuous Integration
-on: [ "pull_request", "workflow_dispatch" ]
+on:
+ - pull_request
+ - workflow_dispatch
env:
- MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
- MODEL_NAME: "codellama-7b.Q2_K.gguf"
+ MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf
+ MODEL_NAME: codellama-7b.Q2_K.gguf
+ RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf
+ RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf
jobs:
- # don't split build and test jobs to keep the workflow simple
build-and-test-linux:
name: ubuntu-latest
runs-on: ubuntu-latest
@@ -16,15 +17,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4
with:
- distribution: 'zulu'
- java-version: '11'
+ distribution: zulu
+ java-version: "11"
- name: Build libraries
- # cmake should figure out OS and ARCH automatically when running build.sh (but we need mvn compile for it)
run: |
mvn compile
.github/build.sh -DLLAMA_VERBOSE=ON
- - name: Download model
+ - name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+ - name: Download reranking model
+ run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+ - name: List files in models directory
+ run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
@@ -41,26 +45,26 @@ jobs:
fail-fast: false
matrix:
target:
- - {
- runner: macos-13,
- cmake: '-DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON'
- }
- - {
- runner: macos-14,
- cmake: '-DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON'
- }
+ - runner: macos-13
+ cmake: -DLLAMA_METAL=OFF -DLLAMA_VERBOSE=ON
+ - runner: macos-14
+ cmake: -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_VERBOSE=ON
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4
with:
- distribution: 'zulu'
- java-version: '11'
+ distribution: zulu
+ java-version: "11"
- name: Build libraries
run: |
mvn compile
.github/build.sh ${{ matrix.target.cmake }}
- - name: Download model
+ - name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+ - name: Download reranking model
+ run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+ - name: List files in models directory
+ run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
@@ -71,8 +75,8 @@ jobs:
if-no-files-found: warn
build-and-test-windows:
- name: windows-latest
- runs-on: windows-latest
+ name: windows-2019
+ runs-on: windows-2019
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4
@@ -85,11 +89,17 @@ jobs:
.github\build.bat -DLLAMA_VERBOSE=ON
- name: Download model
run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME
+ - name: Download reranking model
+ run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME
+ - name: List files in models directory
+ run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
uses: actions/upload-artifact@v4
with:
- name: error-log-windows
- path: ${{ github.workspace }}\hs_err_pid*.log
+ name: windows-output
+ path: |
+ ${{ github.workspace }}\hs_err_pid*.log
+ ${{ github.workspace }}/src/main/resources/de/kherud/llama/**/*
if-no-files-found: warn
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 85829ed9..64032028 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -11,22 +11,25 @@ on:
env:
MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
MODEL_NAME: "codellama-7b.Q2_K.gguf"
+ RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf"
+ RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf"
jobs:
- build-linux-cuda:
- name: Build Linux x86-64 CUDA12
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - name: Build libraries
- shell: bash
- run: |
- .github/dockcross/dockcross-manylinux_2_28-x64 .github/build_cuda_linux.sh "-DOS_NAME=Linux -DOS_ARCH=x86_64"
- - name: Upload artifacts
- uses: actions/upload-artifact@v4
- with:
- name: linux-libraries-cuda
- path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
+# todo: doesn't work with the newest llama.cpp version
+# build-linux-cuda:
+# name: Build Linux x86-64 CUDA12
+# runs-on: ubuntu-latest
+# steps:
+# - uses: actions/checkout@v4
+# - name: Build libraries
+# shell: bash
+# run: |
+# .github/dockcross/dockcross-manylinux_2_28-x64 .github/build_cuda_linux.sh "-DOS_NAME=Linux -DOS_ARCH=x86_64"
+# - name: Upload artifacts
+# uses: actions/upload-artifact@v4
+# with:
+# name: linux-libraries-cuda
+# path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
build-linux-docker:
name: Build ${{ matrix.target.os }}-${{ matrix.target.arch }}
@@ -94,7 +97,7 @@ jobs:
build-win-native:
name: Build ${{ matrix.target.os }}-${{ matrix.target.arch }}
- runs-on: windows-latest
+ runs-on: windows-2019
strategy:
fail-fast: false
matrix:
@@ -102,23 +105,24 @@ jobs:
- {
os: Windows,
arch: x86_64,
- cmake: '-G "Visual Studio 17 2022" -A "x64"'
- }
- - {
- os: Windows,
- arch: aarch64,
- cmake: '-G "Visual Studio 17 2022" -A "ARM64"'
+ cmake: '-G "Visual Studio 16 2019" -A "x64"'
}
- {
os: Windows,
arch: x86,
- cmake: '-G "Visual Studio 17 2022" -A "Win32"'
- }
- - {
- os: Windows,
- arch: arm,
- cmake: '-G "Visual Studio 17 2022" -A "ARM"'
+ cmake: '-G "Visual Studio 16 2019" -A "Win32"'
}
+# MSVC aarch64 builds no longer work with llama.cpp (requires clang instead)
+# - {
+# os: Windows,
+# arch: aarch64,
+# cmake: '-G "Visual Studio 16 2019" -A "ARM64"'
+# }
+# - {
+# os: Windows,
+# arch: arm,
+# cmake: '-G "Visual Studio 16 2019" -A "ARM"'
+# }
steps:
- uses: actions/checkout@v4
- name: Build libraries
@@ -142,8 +146,10 @@ jobs:
with:
name: Linux-x86_64-libraries
path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
- - name: Download model
+ - name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+ - name: Download reranking model
+ run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
- uses: actions/setup-java@v4
with:
distribution: 'zulu'
@@ -193,7 +199,7 @@ jobs:
publish:
if: ${{ github.event_name != 'workflow_dispatch' || github.event.inputs.build_only == 'no' }}
- needs: [ test-linux,build-macos-native,build-win-native,build-linux-cuda ]
+ needs: [ test-linux,build-macos-native,build-win-native ] #,build-linux-cuda
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
@@ -202,10 +208,10 @@ jobs:
pattern: "*-libraries"
merge-multiple: true
path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
- - uses: actions/download-artifact@v4
- with:
- name: linux-libraries-cuda
- path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
+# - uses: actions/download-artifact@v4
+# with:
+# name: linux-libraries-cuda
+# path: ${{ github.workspace }}/src/main/resources_linux_cuda/de/kherud/llama/
- name: Set up Maven Central Repository
uses: actions/setup-java@v3
with:
diff --git a/.gitignore b/.gitignore
index 8857fd04..274f8687 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
.idea
target
build
+cmake-build-*
.DS_Store
.directory
.vscode
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 847465e6..96c62950 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,6 +6,7 @@ include(FetchContent)
set(BUILD_SHARED_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(BUILD_SHARED_LIBS OFF)
option(LLAMA_VERBOSE "llama: verbose output" OFF)
@@ -20,10 +21,11 @@ FetchContent_MakeAvailable(json)
#################### llama.cpp ####################
+set(LLAMA_BUILD_COMMON ON)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
- GIT_TAG b3534
+ GIT_TAG b4916
)
FetchContent_MakeAvailable(llama.cpp)
@@ -67,7 +69,7 @@ endif()
# include jni.h and jni_md.h
if(NOT DEFINED JNI_INCLUDE_DIRS)
- if(OS_NAME MATCHES "^Linux" OR OS_NAME STREQUAL "Mac")
+ if(OS_NAME MATCHES "^Linux" OR OS_NAME STREQUAL "Mac" OR OS_NAME STREQUAL "Darwin")
set(JNI_INCLUDE_DIRS .github/include/unix)
elseif(OS_NAME STREQUAL "Windows")
set(JNI_INCLUDE_DIRS .github/include/windows)
@@ -102,9 +104,10 @@ target_compile_definitions(jllama PRIVATE
)
if(OS_NAME STREQUAL "Windows")
- set_target_properties(jllama llama ggml PROPERTIES
+ set_target_properties(jllama llama ggml PROPERTIES
RUNTIME_OUTPUT_DIRECTORY_DEBUG ${JLLAMA_DIR}
RUNTIME_OUTPUT_DIRECTORY_RELEASE ${JLLAMA_DIR}
+ RUNTIME_OUTPUT_DIRECTORY_RELWITHDEBINFO ${JLLAMA_DIR}
)
else()
set_target_properties(jllama llama ggml PROPERTIES
diff --git a/README.md b/README.md
index 718ec4be..1bc278b1 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@

-
+
# Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp)
@@ -17,7 +17,7 @@ Inference of Meta's LLaMA model (and others) in pure C/C++.
3. [Android](#importing-in-android)
> [!NOTE]
-> Now with support for Llama 3, Phi-3, and flash attention
+> Now with support for Gemma 3
## Quick Start
@@ -27,18 +27,7 @@ Access this library via Maven:
de.kherud
llama
- 3.4.1
-
-```
-
-Bu default the default library artifact is built only with CPU inference support. To enable CUDA, use a `cuda12-linux-x86-64` maven classifier:
-
-```xml
-
- de.kherud
- llama
- 3.4.1
- cuda12-linux-x86-64
+ 4.1.0
```
@@ -50,11 +39,7 @@ We support CPU inference for the following platforms out of the box:
- Linux x86-64, aarch64
- MacOS x86-64, aarch64 (M-series)
-- Windows x86-64, x64, arm (32 bit)
-
-For GPU inference, we support:
-
-- Linux x86-64 with CUDA 12.1+
+- Windows x86-64, x64
If any of these match your platform, you can include the Maven dependency and get started.
@@ -78,7 +63,7 @@ cmake --build build --config Release
```
> [!TIP]
-> Use `-DGGML_CURL=ON` to download models via Java code using `ModelParameters#setModelUrl(String)`.
+> Use `-DLLAMA_CURL=ON` to download models via Java code using `ModelParameters#setModelUrl(String)`.
All compiled libraries will be put in a resources directory matching your platform, which will appear in the cmake output. For example something like:
@@ -88,13 +73,9 @@ All compiled libraries will be put in a resources directory matching your platfo
#### Library Location
-This project has to load three shared libraries:
-
-- ggml
-- llama
-- jllama
+This project has to load a single shared library `jllama`.
-Note, that the file names vary between operating systems, e.g., `ggml.dll` on Windows, `libggml.so` on Linux, and `libggml.dylib` on macOS.
+Note that the file name varies between operating systems, e.g., `jllama.dll` on Windows, `libjllama.so` on Linux, and `libjllama.dylib` on macOS.
The application will search in the following order in the following locations:
@@ -105,14 +86,6 @@ The application will search in the following order in the following locations:
- From the **JAR**: If any of the libraries weren't found yet, the application will try to use a prebuilt shared library.
This of course only works for the [supported platforms](#no-setup-required) .
-Not all libraries have to be in the same location.
-For example, if you already have a llama.cpp and ggml version you can install them as a system library and rely on the jllama library from the JAR.
-This way, you don't have to compile anything.
-
-#### CUDA
-
-On Linux x86-64 with CUDA 12.1+, the library assumes that your CUDA libraries are findable in `java.library.path`. If you have CUDA installed in a non-standard location, then point the `java.library.path` to the directory containing the `libcudart.so.12` library.
-
## Documentation
### Example
@@ -124,8 +97,8 @@ public class Example {
public static void main(String... args) throws IOException {
ModelParameters modelParams = new ModelParameters()
- .setModelFilePath("/path/to/model.gguf")
- .setNGpuLayers(43);
+ .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf")
+ .setGpuLayers(43);
String system = "This is a conversation between User and Llama, a friendly chatbot.\n" +
"Llama is helpful, kind, honest, good at writing, and never fails to answer any " +
@@ -144,8 +117,8 @@ public class Example {
InferenceParameters inferParams = new InferenceParameters(prompt)
.setTemperature(0.7f)
.setPenalizeNl(true)
- .setMirostat(InferenceParameters.MiroStat.V2)
- .setAntiPrompt("\n");
+ .setMiroStat(MiroStat.V2)
+ .setStopStrings("User:");
for (LlamaOutput output : model.generate(inferParams)) {
System.out.print(output);
prompt += output;
@@ -165,7 +138,7 @@ model to your prompt in order to extend the context. If there is repeated conten
cache this, to improve performance.
```java
-ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf");
+ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf");
InferenceParameters inferParams = new InferenceParameters("Tell me a joke.");
try (LlamaModel model = new LlamaModel(modelParams)) {
// Stream a response and access more information about each output.
@@ -197,9 +170,8 @@ for every inference task. All non-specified options have sensible defaults.
```java
ModelParameters modelParams = new ModelParameters()
- .setModelFilePath("/path/to/model.gguf")
- .setLoraAdapter("/path/to/lora/adapter")
- .setLoraBase("/path/to/lora/base");
+ .setModel("/path/to/model.gguf")
+ .addLoraAdapter("/path/to/lora/adapter");
String grammar = """
root ::= (expr "=" term "\\n")+
expr ::= term ([-+*/] term)*
@@ -234,7 +206,7 @@ LlamaModel.setLogger(null, (level, message) -> {});
## Importing in Android
You can use this library in Android project.
-1. Add java-llama.cpp as a submodule in your android `app` project directory
+1. Add java-llama.cpp as a submodule in your android `app` project directory
```shell
git submodule add https://github.com/kherud/java-llama.cpp
```
diff --git a/pom.xml b/pom.xml
index 68674de9..67b366ee 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,14 +1,16 @@
-
4.0.0
de.kherud
llama
- 3.4.1
+ 4.2.0
jar
${project.groupId}:${project.artifactId}
- Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++.
+ Java Bindings for llama.cpp - A Port of Facebook's LLaMA model
+ in C/C++.
https://github.com/kherud/java-llama.cpp
@@ -39,7 +41,8 @@
ossrh
- https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/
+
+ https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/
@@ -71,17 +74,21 @@
maven-compiler-plugin
3.13.0
-
+
gpu
compile
- compile
+
+ compile
+
-h
src/main/cpp
- ${project.build.outputDirectory}_cuda
+
+ ${project.build.outputDirectory}_cuda
@@ -98,10 +105,12 @@
copy-resources
- ${project.build.outputDirectory}_cuda
+
+ ${project.build.outputDirectory}_cuda
- ${basedir}/src/main/resources_linux_cuda/
+
+ ${basedir}/src/main/resources_linux_cuda/
**/*.*
@@ -176,7 +185,8 @@
maven-jar-plugin
3.4.2
-
+
cuda
package
@@ -185,7 +195,8 @@
cuda12-linux-x86-64
- ${project.build.outputDirectory}_cuda
+
+ ${project.build.outputDirectory}_cuda
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index d59f3b77..11c80ae0 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -1,18 +1,21 @@
#include "jllama.h"
+#include "arg.h"
+#include "json-schema-to-grammar.h"
#include "llama.h"
+#include "log.h"
#include "nlohmann/json.hpp"
#include "server.hpp"
#include
+#include
#include
// We store some references to Java classes and their fields/methods here to speed up things for later and to fail
// early on if anything can't be found. This happens when the JVM loads the shared library (see `JNI_OnLoad`).
// The references remain valid throughout the whole life of the shared library, on `JNI_OnUnload` they are released.
-namespace
-{
+namespace {
JavaVM *g_vm = nullptr;
// classes
@@ -78,8 +81,7 @@ jobject o_log_callback = nullptr;
/**
* Convert a Java string to a std::string
*/
-std::string parse_jstring(JNIEnv *env, jstring java_string)
-{
+std::string parse_jstring(JNIEnv *env, jstring java_string) {
auto *const string_bytes = (jbyteArray)env->CallObjectMethod(java_string, m_get_bytes, o_utf_8);
auto length = (size_t)env->GetArrayLength(string_bytes);
@@ -93,13 +95,38 @@ std::string parse_jstring(JNIEnv *env, jstring java_string)
return string;
}
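+/**
+ * Convert a Java String[] into a heap-allocated, argv-style char* array. The returned array and the
+ * strings it holds must be released with free_string_array.
+ */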
+char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) {
+ auto *const result = static_cast<char **>(malloc(length * sizeof(char *)));
+
+ if (result == nullptr) {
+ return nullptr;
+ }
+
+ for (jsize i = 0; i < length; i++) {
+ auto *const javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
+ const char *cString = env->GetStringUTFChars(javaString, nullptr);
+ result[i] = strdup(cString);
+ env->ReleaseStringUTFChars(javaString, cString);
+ }
+
+ return result;
+}
+
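+/**
+ * Free a char* array previously created by parse_string_array.
+ */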
+void free_string_array(char **array, jsize length) {
+ if (array != nullptr) {
+ for (jsize i = 0; i < length; i++) {
+ free(array[i]);
+ }
+ free(array);
+ }
+}
+
/**
* Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env-NewString`,
* but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to
* do this conversion in C++
*/
-jbyteArray parse_jbytes(JNIEnv *env, const std::string &string)
-{
+jbyteArray parse_jbytes(JNIEnv *env, const std::string &string) {
jsize length = string.size(); // NOLINT(*-narrowing-conversions)
jbyteArray bytes = env->NewByteArray(length);
env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast<const jbyte *>(string.c_str()));
@@ -109,10 +136,8 @@ jbyteArray parse_jbytes(JNIEnv *env, const std::string &string)
/**
* Map a llama.cpp log level to its Java enumeration option.
*/
-jobject log_level_to_jobject(ggml_log_level level)
-{
- switch (level)
- {
+jobject log_level_to_jobject(ggml_log_level level) {
+ switch (level) {
case GGML_LOG_LEVEL_ERROR:
return o_log_level_error;
case GGML_LOG_LEVEL_WARN:
@@ -128,31 +153,27 @@ jobject log_level_to_jobject(ggml_log_level level)
/**
* Returns the JNIEnv of the current thread.
*/
-JNIEnv *get_jni_env()
-{
+JNIEnv *get_jni_env() {
JNIEnv *env = nullptr;
- if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK)
- {
+ if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
throw std::runtime_error("Thread is not attached to the JVM");
}
return env;
}
+bool log_json;
+std::function<void(ggml_log_level, const char *, void *)> log_callback;
+
/**
* Invoke the log callback if there is any.
*/
-void log_callback_trampoline(ggml_log_level level, const char *text, void *user_data)
-{
- if (log_callback != nullptr)
- {
+void log_callback_trampoline(ggml_log_level level, const char *text, void *user_data) {
+ if (log_callback != nullptr) {
log_callback(level, text, user_data);
}
}
} // namespace
-bool log_json;
-std::function log_callback;
-
/**
* The VM calls JNI_OnLoad when the native library is loaded (for example, through `System.loadLibrary`).
* `JNI_OnLoad` must return the JNI version needed by the native library.
@@ -161,13 +182,11 @@ std::function log_callback;
* only requires JNI version `JNI_VERSION_1_1`. If the VM does not recognize the version number returned by
`JNI_OnLoad`, the VM will unload the library and act as if the library was never loaded.
*/
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
-{
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
g_vm = vm;
JNIEnv *env = nullptr;
- if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1))
- {
+ if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) {
goto error;
}
@@ -192,8 +211,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
if (!(c_llama_model && c_llama_iterator && c_standard_charsets && c_output && c_string && c_hash_map && c_map &&
c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level &&
- c_log_format && c_error_oom))
- {
+ c_log_format && c_error_oom)) {
goto error;
}
@@ -221,8 +239,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
cc_integer = env->GetMethodID(c_integer, "<init>", "(I)V");
cc_float = env->GetMethodID(c_float, "<init>", "(F)V");
- if (!(cc_output && cc_hash_map && cc_integer && cc_float))
- {
+ if (!(cc_output && cc_hash_map && cc_integer && cc_float)) {
goto error;
}
@@ -240,8 +257,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
m_biconsumer_accept = env->GetMethodID(c_biconsumer, "accept", "(Ljava/lang/Object;Ljava/lang/Object;)V");
if (!(m_get_bytes && m_entry_set && m_set_iterator && m_iterator_has_next && m_iterator_next && m_entry_key &&
- m_entry_value && m_map_put && m_int_value && m_float_value && m_biconsumer_accept))
- {
+ m_entry_value && m_map_put && m_int_value && m_float_value && m_biconsumer_accept)) {
goto error;
}
@@ -258,8 +274,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;");
if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next && f_log_level_debug && f_log_level_info &&
- f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text))
- {
+ f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text)) {
goto error;
}
@@ -272,8 +287,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
o_log_format_text = env->GetStaticObjectField(c_log_format, f_log_format_text);
if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error &&
- o_log_format_json && o_log_format_text))
- {
+ o_log_format_json && o_log_format_text)) {
goto error;
}
@@ -285,8 +299,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
o_log_format_json = env->NewGlobalRef(o_log_format_json);
o_log_format_text = env->NewGlobalRef(o_log_format_text);
- if (env->ExceptionCheck())
- {
+ if (env->ExceptionCheck()) {
env->ExceptionDescribe();
goto error;
}
@@ -310,12 +323,10 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved)
* Note that `JNI_OnLoad` and `JNI_OnUnload` are two functions optionally supplied by JNI libraries, not exported from
* the VM.
*/
-JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
-{
+JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) {
JNIEnv *env = nullptr;
- if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6))
- {
+ if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6)) {
return;
}
@@ -344,63 +355,53 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved)
env->DeleteGlobalRef(o_log_format_json);
env->DeleteGlobalRef(o_log_format_text);
- if (o_log_callback != nullptr)
- {
+ if (o_log_callback != nullptr) {
env->DeleteGlobalRef(o_log_callback);
}
llama_backend_free();
}
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jstring jparams)
-{
- gpt_params params;
-
- auto *ctx_server = new server_context();
-
- std::string c_params = parse_jstring(env, jparams);
- json json_params = json::parse(c_params);
- server_params_parse(json_params, params);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) {
+ common_params params;
- if (json_value(json_params, "disable_log", false))
- {
- log_disable();
- }
- else
- {
- log_enable();
+ const jsize argc = env->GetArrayLength(jparams);
+ char **argv = parse_string_array(env, jparams, argc);
+ if (argv == nullptr) {
+ return;
}
- if (!params.system_prompt.empty())
- {
- ctx_server->system_prompt_set(params.system_prompt);
+ const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER);
+ free_string_array(argv, argc);
+ if (!parsed_params) {
+ return;
}
- if (params.model_alias == "unknown")
- {
- params.model_alias = params.model;
- }
+ SRV_INF("loading model '%s'\n", params.model.c_str());
- llama_numa_init(params.numa);
+ common_init();
+
+ // struct that contains llama context and inference
+ auto *ctx_server = new server_context();
- LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
+ llama_numa_init(params.numa);
- LOG_INFO("system info", {
- {"n_threads", params.n_threads},
- {"n_threads_batch", params.n_threads_batch},
- {"total_threads", std::thread::hardware_concurrency()},
- {"system_info", llama_print_system_info()},
- });
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads,
+ params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+ LOG_INF("\n");
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
// Necessary similarity of prompt for slot selection
ctx_server->slot_prompt_similarity = params.slot_prompt_similarity;
+ LOG_INF("%s: loading model\n", __func__);
+
// load the model
- if (!ctx_server->load_model(params))
- {
- state.store(SERVER_STATE_ERROR);
+ if (!ctx_server->load_model(params)) {
+ llama_backend_free();
env->ThrowNew(c_llama_error, "could not load model from given file path");
return;
}
@@ -408,60 +409,70 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
ctx_server->init();
state.store(SERVER_STATE_READY);
- LOG_INFO("model loaded", {});
+ LOG_INF("%s: model loaded\n", __func__);
const auto model_meta = ctx_server->model_meta();
- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
- if (params.chat_template.empty())
- {
- if (!ctx_server->validate_model_chat_template())
- {
- LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This "
- "may cause the model to output suboptimal responses",
- {});
- params.chat_template = "chatml";
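+ // if configured, load a draft model for speculative decoding and prepare its context parameters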
+ if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) {
+ SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
+ auto params_dft = params;
+
+ params_dft.devices = params.speculative.devices;
+ params_dft.hf_file = params.speculative.hf_file;
+ params_dft.hf_repo = params.speculative.hf_repo;
+ params_dft.model = params.speculative.model;
+ params_dft.model_url = params.speculative.model_url;
+ params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx;
+ params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
+ params_dft.n_parallel = 1;
+
+ common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+ llama_model *model_dft = llama_init_dft.model.get();
+
+ if (model_dft == nullptr) {
+ SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
}
- }
- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
- if (params.chat_template.empty())
- {
- if (!ctx_server->validate_model_chat_template())
- {
- LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This "
- "may cause the model to output suboptimal responses",
- {});
- params.chat_template = "chatml";
+ if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) {
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
+ params.speculative.model.c_str(), params.model.c_str());
}
+
+ const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
+
+ ctx_server->cparams_dft = common_context_params_to_llama(params_dft);
+ ctx_server->cparams_dft.n_batch = n_ctx_dft;
+
+ // force F16 KV cache for the draft model for extra performance
+ ctx_server->cparams_dft.type_k = GGML_TYPE_F16;
+ ctx_server->cparams_dft.type_v = GGML_TYPE_F16;
+
+ // the context is not needed - we will create one for each slot
+ llama_init_dft.context.reset();
}
// print sample chat example to make it clear which template is used
- {
- LOG_INFO("chat template",
- {
- {"chat_example", llama_chat_format_example(ctx_server->model, params.chat_template)},
- {"built_in", params.chat_template.empty()},
- });
- }
+ LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
+ common_chat_templates_source(ctx_server->chat_templates.get()),
+ common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str());
+
+ // print sample chat example to make it clear which template is used
+ // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
+ // common_chat_templates_source(ctx_server->chat_templates.get()),
+ // common_chat_format_example(*ctx_server->chat_templates.template_default,
+ // ctx_server->params_base.use_jinja) .c_str());
ctx_server->queue_tasks.on_new_task(
std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1));
- ctx_server->queue_tasks.on_finish_multitask(
- std::bind(&server_context::on_finish_multitask, ctx_server, std::placeholders::_1));
ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server));
- ctx_server->queue_results.on_multitask_update(std::bind(&server_queue::update_multitask, &ctx_server->queue_tasks,
- std::placeholders::_1, std::placeholders::_2,
- std::placeholders::_3));
std::thread t([ctx_server]() {
JNIEnv *env;
jint res = g_vm->GetEnv((void **)&env, JNI_VERSION_1_6);
- if (res == JNI_EDETACHED)
- {
+ if (res == JNI_EDETACHED) {
res = g_vm->AttachCurrentThread((void **)&env, nullptr);
- if (res != JNI_OK)
- {
+ if (res != JNI_OK) {
throw std::runtime_error("Failed to attach thread to JVM");
}
}
@@ -472,60 +483,95 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
}
-JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams)
-{
+JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
std::string c_params = parse_jstring(env, jparams);
- json json_params = json::parse(c_params);
- const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix");
+ json data = json::parse(c_params);
+
+ server_task_type type = SERVER_TASK_TYPE_COMPLETION;
+
+ if (data.contains("input_prefix") || data.contains("input_suffix")) {
+ type = SERVER_TASK_TYPE_INFILL;
+ }
+
+ auto completion_id = gen_chatcmplid();
+ std::vector<server_task> tasks;
+
+ try {
+ const auto &prompt = data.at("prompt");
- if (json_params.value("use_chat_template", false))
- {
- json chat;
- chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}});
- chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}});
- json_params["prompt"] = format_chat(ctx_server->model, ctx_server->params.chat_template, chat);
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+ tasks.reserve(tokenized_prompts.size());
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(type);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data);
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // OAI-compat
+ task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+ task.params.oaicompat_cmpl_id = completion_id;
+ // oaicompat_model is already populated by params_from_json_cmpl
+
+ tasks.push_back(task);
+ }
+ } catch (const std::exception &e) {
+ const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+ env->ThrowNew(c_llama_error, err.dump().c_str());
+ return 0;
+ }
+
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ const auto task_ids = server_task::get_list_id(tasks);
+
+ if (task_ids.size() != 1) {
+ env->ThrowNew(c_llama_error, "multitasking currently not supported");
+ return 0;
}
- const int id_task = ctx_server->queue_tasks.get_new_id();
- ctx_server->queue_results.add_waiting_task_id(id_task);
- ctx_server->request_completion(id_task, -1, json_params, infill, false);
+ return *task_ids.begin();
+}
- return id_task;
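+/**
+ * Stop waiting for results of the given task id so its queued outputs can be dropped.
+ */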
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) {
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+ ctx_server->queue_results.remove_waiting_task_id(id_task);
}
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task)
-{
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- server_task_result result = ctx_server->queue_results.recv(id_task);
+ server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
- if (result.error)
- {
- std::string response = result.data["message"].get();
+ if (result->is_error()) {
+ std::string response = result->to_json()["message"].get<std::string>();
ctx_server->queue_results.remove_waiting_task_id(id_task);
env->ThrowNew(c_llama_error, response.c_str());
return nullptr;
}
+ const auto out_res = result->to_json();
- std::string response = result.data["content"].get();
- if (result.stop)
- {
+ std::string response = out_res["content"].get<std::string>();
+ if (result->is_stop()) {
ctx_server->queue_results.remove_waiting_task_id(id_task);
}
jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
- if (result.data.contains("completion_probabilities"))
- {
- auto completion_probabilities = result.data["completion_probabilities"];
- for (const auto &entry : completion_probabilities)
- {
+ if (out_res.contains("completion_probabilities")) {
+ auto completion_probabilities = out_res["completion_probabilities"];
+ for (const auto &entry : completion_probabilities) {
auto probs = entry["probs"];
- for (const auto &tp : probs)
- {
+ for (const auto &tp : probs) {
std::string tok_str = tp["tok_str"];
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
float prob = tp["prob"];
@@ -536,18 +582,15 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE
}
}
}
-
jbyteArray jbytes = parse_jbytes(env, response);
- return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result.stop);
+ return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop());
}
-JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt)
-{
+JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- if (!ctx_server->params.embedding)
- {
+ if (!ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
"model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
return nullptr;
@@ -555,46 +598,187 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
const std::string prompt = parse_jstring(env, jprompt);
- const int id_task = ctx_server->queue_tasks.get_new_id();
- ctx_server->queue_results.add_waiting_task_id(id_task);
- ctx_server->request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+ SRV_INF("Calling embedding '%s'\n", prompt.c_str());
- server_task_result result = ctx_server->queue_results.recv(id_task);
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- if (result.error)
- {
- std::string response = result.data["message"].get();
+ const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true);
+ std::vector<server_task> tasks;
+
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = 0;
+ task.prompt_tokens = std::move(tokens);
+
+ // OAI-compat
+ task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+
+ tasks.push_back(task);
+
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+ const auto id_task = *task_ids.begin();
+ json responses = json::array();
+
+ json error = nullptr;
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
+
+ json response_str = result->to_json();
+ if (result->is_error()) {
+ std::string response = result->to_json()["message"].get<std::string>();
+ ctx_server->queue_results.remove_waiting_task_id(id_task);
env->ThrowNew(c_llama_error, response.c_str());
return nullptr;
}
- std::vector embedding = result.data["embedding"].get>();
- jsize embedding_size = embedding.size(); // NOLINT(*-narrowing-conversions)
+ if (result->is_stop()) {
+ ctx_server->queue_results.remove_waiting_task_id(id_task);
+ }
+
+ const auto out_res = result->to_json();
+
+ // Extract "embedding" as a vector of vectors (2D array)
+ std::vector<std::vector<float>> embedding = out_res["embedding"].get<std::vector<std::vector<float>>>();
+
+ // Get total number of rows in the embedding
+ jsize embedding_rows = embedding.size();
+
+ // Get total number of columns in the first row (assuming all rows are of equal length)
+ jsize embedding_cols = embedding_rows > 0 ? embedding[0].size() : 0;
+
+ SRV_INF("Embedding has %d rows and %d columns\n", embedding_rows, embedding_cols);
+
+ // Ensure embedding is not empty
+ if (embedding.empty() || embedding[0].empty()) {
+ env->ThrowNew(c_error_oom, "embedding array is empty");
+ return nullptr;
+ }
+
+ // Extract only the first row
+ const std::vector<float> &first_row = embedding[0]; // Reference to avoid copying
- jfloatArray j_embedding = env->NewFloatArray(embedding_size);
- if (j_embedding == nullptr)
- {
+ // Create a new float array in JNI
+ jfloatArray j_embedding = env->NewFloatArray(embedding_cols);
+ if (j_embedding == nullptr) {
env->ThrowNew(c_error_oom, "could not allocate embedding");
return nullptr;
}
- env->SetFloatArrayRegion(j_embedding, 0, embedding_size, reinterpret_cast(embedding.data()));
+ // Copy the first row into the JNI float array
+ env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, reinterpret_cast(first_row.data()));
return j_embedding;
}
-JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt)
-{
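+/**
+ * Rerank a set of documents against a query prompt and return a map from each document to its relevance score.
+ */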
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt,
+ jobjectArray documents) {
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+ if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error,
+ "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+ return nullptr;
+ }
+
+ const std::string prompt = parse_jstring(env, jprompt);
+
+ const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
+
+ json responses = json::array();
+
+ std::vector<server_task> tasks;
+ const jsize amount_documents = env->GetArrayLength(documents);
+ auto *document_array = parse_string_array(env, documents, amount_documents);
+ auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
+ free_string_array(document_array, amount_documents);
+
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true);
+
+ tasks.reserve(tokenized_docs.size());
+ for (int i = 0; i < tokenized_docs.size(); i++) {
+ auto task = server_task(SERVER_TASK_TYPE_RERANK);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+ task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
+ tasks.push_back(task);
+ }
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ // get the result
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+ std::vector<server_task_result_ptr> results(task_ids.size());
+
+ // Create a new HashMap instance
+ jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
+ if (o_probabilities == nullptr) {
+ env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
+ return nullptr;
+ }
+
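+ // collect one result per document; each result's "index" maps back to the original document text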
+ for (int i = 0; i < (int)task_ids.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+ if (result->is_error()) {
+ auto response = result->to_json()["message"].get<std::string>();
+ for (const int id_task : task_ids) {
+ ctx_server->queue_results.remove_waiting_task_id(id_task);
+ }
+ env->ThrowNew(c_llama_error, response.c_str());
+ return nullptr;
+ }
+
+ const auto out_res = result->to_json();
+
+ if (result->is_stop()) {
+ for (const int id_task : task_ids) {
+ ctx_server->queue_results.remove_waiting_task_id(id_task);
+ }
+ }
+
+ int index = out_res["index"].get<int>();
+ float score = out_res["score"].get<float>();
+ std::string tok_str = document_vector[index];
+ jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+ jobject jprob = env->NewObject(c_float, cc_float, score);
+ env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
+ env->DeleteLocalRef(jtok_str);
+ env->DeleteLocalRef(jprob);
+ }
+ jbyteArray jbytes = parse_jbytes(env, prompt);
+ return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
+}
+
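+/**
+ * Apply the model's chat template to the given messages (JSON) and return the formatted prompt string.
+ */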
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+ std::string c_params = parse_jstring(env, jparams);
+ json data = json::parse(c_params);
+
+ json templateData =
+ oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
+ ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
+ std::string tok_str = templateData.at("prompt");
+ jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+ return jtok_str;
+}
+
+JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
const std::string c_prompt = parse_jstring(env, jprompt);
- std::vector tokens = ctx_server->tokenize(c_prompt, false);
+
+ llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true);
jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions)
jintArray java_tokens = env->NewIntArray(token_size);
- if (java_tokens == nullptr)
- {
+ if (java_tokens == nullptr) {
env->ThrowNew(c_error_oom, "could not allocate token memory");
return nullptr;
}
@@ -605,8 +789,7 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env,
}
JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj,
- jintArray java_tokens)
-{
+ jintArray java_tokens) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
@@ -620,39 +803,33 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv
return parse_jbytes(env, text);
}
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobject obj)
-{
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobject obj) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
ctx_server->queue_tasks.terminate();
- delete ctx_server;
+ // delete ctx_server;
}
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task)
-{
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- ctx_server->request_cancel(id_task);
+ std::unordered_set<int> id_tasks = {id_task};
+ ctx_server->cancel_tasks(id_tasks);
ctx_server->queue_results.remove_waiting_task_id(id_task);
}
JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format,
- jobject jcallback)
-{
- if (o_log_callback != nullptr)
- {
+ jobject jcallback) {
+ if (o_log_callback != nullptr) {
env->DeleteGlobalRef(o_log_callback);
}
log_json = env->IsSameObject(log_format, o_log_format_json);
- if (jcallback == nullptr)
- {
+ if (jcallback == nullptr) {
log_callback = nullptr;
llama_log_set(nullptr, nullptr);
- }
- else
- {
+ } else {
o_log_callback = env->NewGlobalRef(jcallback);
log_callback = [](enum ggml_log_level level, const char *text, void *user_data) {
JNIEnv *env = get_jni_env();
@@ -661,9 +838,16 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jc
env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message);
env->DeleteLocalRef(message);
};
- if (!log_json)
- {
+ if (!log_json) {
llama_log_set(log_callback_trampoline, nullptr);
}
}
}
+
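+/**
+ * Convert a JSON schema to a GBNF grammar string, returned as UTF-8 bytes.
+ */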
+JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz,
+ jstring j_schema) {
+ const std::string c_schema = parse_jstring(env, j_schema);
+ nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema);
+ const std::string c_grammar = json_schema_to_grammar(c_schema_json);
+ return parse_jbytes(env, c_grammar);
+}
diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h
index 2fd0529e..dc17fa83 100644
--- a/src/main/cpp/jllama.h
+++ b/src/main/cpp/jllama.h
@@ -12,72 +12,91 @@ extern "C" {
* Method: embed
* Signature: (Ljava/lang/String;)[F
*/
-JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed
- (JNIEnv *, jobject, jstring);
+JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *, jobject, jstring);
/*
* Class: de_kherud_llama_LlamaModel
* Method: encode
* Signature: (Ljava/lang/String;)[I
*/
-JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode
- (JNIEnv *, jobject, jstring);
+JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *, jobject, jstring);
/*
* Class: de_kherud_llama_LlamaModel
* Method: setLogger
* Signature: (Lde/kherud/llama/args/LogFormat;Ljava/util/function/BiConsumer;)V
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger
- (JNIEnv *, jclass, jobject, jobject);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *, jclass, jobject, jobject);
/*
* Class: de_kherud_llama_LlamaModel
* Method: requestCompletion
* Signature: (Ljava/lang/String;)I
*/
-JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion
- (JNIEnv *, jobject, jstring);
+JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *, jobject, jstring);
/*
* Class: de_kherud_llama_LlamaModel
* Method: receiveCompletion
* Signature: (I)Lde/kherud/llama/LlamaOutput;
*/
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion
- (JNIEnv *, jobject, jint);
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *, jobject, jint);
/*
* Class: de_kherud_llama_LlamaModel
* Method: cancelCompletion
* Signature: (I)V
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion
- (JNIEnv *, jobject, jint);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *, jobject, jint);
/*
* Class: de_kherud_llama_LlamaModel
* Method: decodeBytes
* Signature: ([I)[B
*/
-JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes
- (JNIEnv *, jobject, jintArray);
+JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *, jobject, jintArray);
/*
* Class: de_kherud_llama_LlamaModel
* Method: loadModel
- * Signature: (Ljava/lang/String;)V
+ * Signature: ([Ljava/lang/String;)V
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel
- (JNIEnv *, jobject, jstring);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *, jobject, jobjectArray);
/*
* Class: de_kherud_llama_LlamaModel
* Method: delete
* Signature: ()V
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete
- (JNIEnv *, jobject);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *, jobject);
+
+/*
+ * Class: de_kherud_llama_LlamaModel
+ * Method: releaseTask
+ * Signature: (I)V
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, jobject, jint);
+
+/*
+ * Class: de_kherud_llama_LlamaModel
+ * Method: jsonSchemaToGrammarBytes
+ * Signature: (Ljava/lang/String;)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring);
+
+/*
+ * Class: de_kherud_llama_LlamaModel
+ * Method: rerank
+ * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput;
+ */
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray);
+
+/*
+ * Class: de_kherud_llama_LlamaModel
+ * Method: applyTemplate
+ * Signature: (Ljava/lang/String;)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring);
#ifdef __cplusplus
}
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 0601dac4..9686f2af 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -1,118 +1,1186 @@
#include "utils.hpp"
-#include "common.h"
-#include "grammar-parser.h"
-#include "llama.h"
-
-#include "nlohmann/json.hpp"
+#include "json-schema-to-grammar.h"
+#include "sampling.h"
+#include "speculative.h"
#include
#include
+#include
#include
#include
+#include
#include
#include
-#include
#include
#include
+#include
+#include
using json = nlohmann::ordered_json;
-enum stop_type
-{
- STOP_TYPE_FULL,
- STOP_TYPE_PARTIAL,
-};
+constexpr int HTTP_POLLING_SECONDS = 1;
-enum slot_state
-{
- SLOT_STATE_IDLE,
- SLOT_STATE_PROCESSING,
+enum stop_type {
+ STOP_TYPE_NONE,
+ STOP_TYPE_EOS,
+ STOP_TYPE_WORD,
+ STOP_TYPE_LIMIT,
};
-enum slot_command
-{
- SLOT_COMMAND_NONE,
- SLOT_COMMAND_LOAD_PROMPT,
- SLOT_COMMAND_RELEASE,
+// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
+enum slot_state {
+ SLOT_STATE_IDLE,
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it
+ // with launch_slot_with_task in the future
+ SLOT_STATE_PROCESSING_PROMPT,
+ SLOT_STATE_DONE_PROMPT,
+ SLOT_STATE_GENERATING,
};
-enum server_state
-{
+enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
- SERVER_STATE_ERROR // An error occurred, load_model failed
};
-enum server_task_type
-{
+enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_EMBEDDING,
+ SERVER_TASK_TYPE_RERANK,
+ SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
+ SERVER_TASK_TYPE_SET_LORA,
};
-struct server_task
-{
- int id = -1; // to be filled by server_queue
- int id_multi = -1;
- int id_target = -1;
+enum oaicompat_type {
+ OAICOMPAT_TYPE_NONE,
+ OAICOMPAT_TYPE_CHAT,
+ OAICOMPAT_TYPE_COMPLETION,
+ OAICOMPAT_TYPE_EMBEDDING,
+};
+
+// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
+enum error_type {
+ ERROR_TYPE_INVALID_REQUEST,
+ ERROR_TYPE_AUTHENTICATION,
+ ERROR_TYPE_SERVER,
+ ERROR_TYPE_NOT_FOUND,
+ ERROR_TYPE_PERMISSION,
+ ERROR_TYPE_UNAVAILABLE, // custom error
+ ERROR_TYPE_NOT_SUPPORTED, // custom error
+};
+
+struct slot_params {
+ bool stream = true;
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+ bool return_tokens = false;
+
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard =
+ 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
+
+ int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
+
+ std::vector lora;
+
+ std::vector<std::string> antiprompt;
+ std::vector<std::string> response_fields;
+ bool timings_per_token = false;
+ bool post_sampling_probs = false;
+ bool ignore_eos = false;
+
+ struct common_params_sampling sampling;
+ struct common_params_speculative speculative;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+
+ json to_json() const {
+ std::vector<std::string> samplers;
+ samplers.reserve(sampling.samplers.size());
+ for (const auto &sampler : sampling.samplers) {
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
+ }
+
+ json lora = json::array();
+ for (size_t i = 0; i < this->lora.size(); ++i) {
+ lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+ }
+
+ auto grammar_triggers = json::array();
+ for (const auto &trigger : sampling.grammar_triggers) {
+ grammar_triggers.push_back(trigger.to_json());
+ }
+
+ return json{
+ {"n_predict", n_predict}, // Server configured n_predict
+ {"seed", sampling.seed},
+ {"temperature", sampling.temp},
+ {"dynatemp_range", sampling.dynatemp_range},
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
+ {"top_k", sampling.top_k},
+ {"top_p", sampling.top_p},
+ {"min_p", sampling.min_p},
+ {"xtc_probability", sampling.xtc_probability},
+ {"xtc_threshold", sampling.xtc_threshold},
+ {"typical_p", sampling.typ_p},
+ {"repeat_last_n", sampling.penalty_last_n},
+ {"repeat_penalty", sampling.penalty_repeat},
+ {"presence_penalty", sampling.penalty_present},
+ {"frequency_penalty", sampling.penalty_freq},
+ {"dry_multiplier", sampling.dry_multiplier},
+ {"dry_base", sampling.dry_base},
+ {"dry_allowed_length", sampling.dry_allowed_length},
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
+ {"mirostat", sampling.mirostat},
+ {"mirostat_tau", sampling.mirostat_tau},
+ {"mirostat_eta", sampling.mirostat_eta},
+ {"stop", antiprompt},
+ {"max_tokens", n_predict}, // User configured n_predict
+ {"n_keep", n_keep},
+ {"n_discard", n_discard},
+ {"ignore_eos", sampling.ignore_eos},
+ {"stream", stream},
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
+ {"n_probs", sampling.n_probs},
+ {"min_keep", sampling.min_keep},
+ {"grammar", sampling.grammar},
+ {"grammar_lazy", sampling.grammar_lazy},
+ {"grammar_triggers", grammar_triggers},
+ {"preserved_tokens", sampling.preserved_tokens},
+ {"chat_format", common_chat_format_name(oaicompat_chat_format)},
+ {"samplers", samplers},
+ {"speculative.n_max", speculative.n_max},
+ {"speculative.n_min", speculative.n_min},
+ {"speculative.p_min", speculative.p_min},
+ {"timings_per_token", timings_per_token},
+ {"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
+ };
+ }
+};
+
+struct server_task {
+ int id = -1; // to be filled by server_queue
+ int index = -1; // used when there are multiple prompts (batch request)
server_task_type type;
- json data;
- bool infill = false;
- bool embedding = false;
+ // used by SERVER_TASK_TYPE_CANCEL
+ int id_target = -1;
+
+ // used by SERVER_TASK_TYPE_INFERENCE
+ slot_params params;
+ llama_tokens prompt_tokens;
+ int id_selected_slot = -1;
+
+ // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
+ struct slot_action {
+ int slot_id;
+ std::string filename;
+ std::string filepath;
+ };
+ slot_action slot_action;
+
+ // used by SERVER_TASK_TYPE_METRICS
+ bool metrics_reset_bucket = false;
+
+ // used by SERVER_TASK_TYPE_SET_LORA
+ std::vector set_lora;
+
+ server_task(server_task_type type) : type(type) {}
+
+ static slot_params params_from_json_cmpl(const llama_context *ctx, const common_params &params_base,
+ const json &data) {
+ const llama_model *model = llama_get_model(ctx);
+ const llama_vocab *vocab = llama_model_get_vocab(model);
+
+ slot_params params;
+
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still
+ // override them)
+ slot_params defaults;
+ defaults.sampling = params_base.sampling;
+ defaults.speculative = params_base.speculative;
+
+ // enabling this will output extra debug information in the HTTP responses from the server
+ params.verbose = params_base.verbosity > 9;
+ params.timings_per_token = json_value(data, "timings_per_token", false);
+
+ params.stream = json_value(data, "stream", false);
+ params.cache_prompt = json_value(data, "cache_prompt", true);
+ params.return_tokens = json_value(data, "return_tokens", false);
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO:
+ // implement
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length =
+ json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n =
+ json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
+
+ params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
+ params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
+ params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
+
+ params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
+ params.speculative.n_min = std::max(params.speculative.n_min, 0);
+ params.speculative.n_max = std::max(params.speculative.n_max, 0);
+
+ // Use OpenAI API logprobs only if n_probs wasn't provided
+ if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
+ params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+ }
+
+ if (data.contains("lora")) {
+ if (data.at("lora").is_array()) {
+ params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
+ } else {
+ throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+ }
+ } else {
+ params.lora = params_base.lora_adapters;
+ }
+
+ // TODO: add more sanity checks for the input parameters
+
+ if (params.sampling.penalty_last_n < -1) {
+ throw std::runtime_error("Error: repeat_last_n must be >= -1");
+ }
+
+ if (params.sampling.dry_penalty_last_n < -1) {
+ throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+ }
+
+ if (params.sampling.penalty_last_n == -1) {
+ // note: should be the slot's context and not the full context, but it's ok
+ params.sampling.penalty_last_n = llama_n_ctx(ctx);
+ }
+
+ if (params.sampling.dry_penalty_last_n == -1) {
+ params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+ }
+
+ if (params.sampling.dry_base < 1.0f) {
+ params.sampling.dry_base = defaults.sampling.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref:
+ // https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ params.sampling.dry_sequence_breakers =
+ json_value(data, "dry_sequence_breakers", std::vector());
+ if (params.sampling.dry_sequence_breakers.empty()) {
+ throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
+ }
+ }
+ }
+
+ // process "json_schema" and "grammar"
+ if (data.contains("json_schema") && !data.contains("grammar")) {
+ try {
+ auto schema = json_value(data, "json_schema", json::object());
+ SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+ params.sampling.grammar = json_schema_to_grammar(schema);
+ SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+ } catch (const std::exception &e) {
+ throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+ }
+ } else {
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+ params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+ SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+ }
+
+ {
+ auto it = data.find("chat_format");
+ if (it != data.end()) {
+ params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
+ SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+ } else {
+ params.oaicompat_chat_format = defaults.oaicompat_chat_format;
+ }
+ }
+
+ {
+ const auto preserved_tokens = data.find("preserved_tokens");
+ if (preserved_tokens != data.end()) {
+ for (const auto &t : *preserved_tokens) {
+ auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false,
+ /* parse_special= */ true);
+ if (ids.size() == 1) {
+ SRV_DBG("Preserved token: %d\n", ids[0]);
+ params.sampling.preserved_tokens.insert(ids[0]);
+ } else {
+ // This may happen when using a tool call style meant for a model with special tokens to
+ // preserve on a model without said tokens.
+ SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str());
+ }
+ }
+ }
+ const auto grammar_triggers = data.find("grammar_triggers");
+ if (grammar_triggers != data.end()) {
+ for (const auto &t : *grammar_triggers) {
+ auto ct = common_grammar_trigger::from_json(t);
+ if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+ const auto &word = ct.value;
+ auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ auto token = ids[0];
+ if (std::find(params.sampling.preserved_tokens.begin(),
+ params.sampling.preserved_tokens.end(),
+ (llama_token)token) == params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error("Grammar trigger word should be marked as preserved token: " +
+ word);
+ }
+ SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
+ common_grammar_trigger trigger;
+ trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+ trigger.value = (llama_token)token;
+ params.sampling.grammar_triggers.push_back(trigger);
+ } else {
+ SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
+ params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+ }
+ } else {
+ params.sampling.grammar_triggers.push_back(ct);
+ }
+ }
+ }
+ if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
+ throw std::runtime_error("Error: no triggers set for lazy grammar!");
+ }
+ }
+
+ {
+ params.sampling.logit_bias.clear();
+ params.ignore_eos = json_value(data, "ignore_eos", false);
+
+ const auto &logit_bias = data.find("logit_bias");
+ if (logit_bias != data.end() && logit_bias->is_array()) {
+ const int n_vocab = llama_vocab_n_tokens(vocab);
+ for (const auto &el : *logit_bias) {
+ // TODO: we may want to throw errors here, in case "el" is incorrect
+ if (el.is_array() && el.size() == 2) {
+ float bias;
+ if (el[1].is_number()) {
+ bias = el[1].get<float>();
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+ bias = -INFINITY;
+ } else {
+ continue;
+ }
+
+ if (el[0].is_number_integer()) {
+ llama_token tok = el[0].get<llama_token>();
+ if (tok >= 0 && tok < n_vocab) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ } else if (el[0].is_string()) {
+ auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
+ for (auto tok : toks) {
+ params.sampling.logit_bias.push_back({tok, bias});
+ }
+ }
+ }
+ }
+ }
+ }
+
+ {
+ params.antiprompt.clear();
+
+ const auto &stop = data.find("stop");
+ if (stop != data.end() && stop->is_array()) {
+ for (const auto &word : *stop) {
+ if (!word.empty()) {
+ params.antiprompt.push_back(word);
+ }
+ }
+ }
+ }
+
+ {
+ const auto samplers = data.find("samplers");
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
+ } else if (samplers->is_string()) {
+ params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
+ }
+ } else {
+ params.sampling.samplers = defaults.sampling.samplers;
+ }
+ }
+
+ std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
+ params.oaicompat_model = json_value(data, "model", model_name);
+
+ return params;
+ }
+
+ // utility function
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> &tasks) {
+ std::unordered_set<int> ids(tasks.size());
+ for (size_t i = 0; i < tasks.size(); i++) {
+ ids.insert(tasks[i].id);
+ }
+ return ids;
+ }
+};
+
+struct result_timings {
+ int32_t prompt_n = -1;
+ double prompt_ms;
+ double prompt_per_token_ms;
+ double prompt_per_second;
+
+ int32_t predicted_n = -1;
+ double predicted_ms;
+ double predicted_per_token_ms;
+ double predicted_per_second;
+
+ json to_json() const {
+ return {
+ {"prompt_n", prompt_n},
+ {"prompt_ms", prompt_ms},
+ {"prompt_per_token_ms", prompt_per_token_ms},
+ {"prompt_per_second", prompt_per_second},
+
+ {"predicted_n", predicted_n},
+ {"predicted_ms", predicted_ms},
+ {"predicted_per_token_ms", predicted_per_token_ms},
+ {"predicted_per_second", predicted_per_second},
+ };
+ }
};
-struct server_task_result
-{
+struct server_task_result {
int id = -1;
- int id_multi = -1;
+ int id_slot = -1;
+ virtual bool is_error() {
+ // only used by server_task_result_error
+ return false;
+ }
+ virtual bool is_stop() {
+ // only used by server_task_result_cmpl_*
+ return false;
+ }
+ virtual int get_index() { return -1; }
+ virtual json to_json() = 0;
+ virtual ~server_task_result() = default;
+};
- json data;
+// using unique_ptr for polymorphism of server_task_result
+using server_task_result_ptr = std::unique_ptr<server_task_result>;
+
+inline std::string stop_type_to_str(stop_type type) {
+ switch (type) {
+ case STOP_TYPE_EOS:
+ return "eos";
+ case STOP_TYPE_WORD:
+ return "word";
+ case STOP_TYPE_LIMIT:
+ return "limit";
+ default:
+ return "none";
+ }
+}
- bool stop;
- bool error;
+struct completion_token_output {
+ llama_token tok;
+ float prob;
+ std::string text_to_send;
+ struct prob_info {
+ llama_token tok;
+ std::string txt;
+ float prob;
+ };
+ std::vector<prob_info> probs;
+
+ json to_json(bool post_sampling_probs) const {
+ json probs_for_token = json::array();
+ for (const auto &p : probs) {
+ std::string txt(p.txt);
+ txt.resize(validate_utf8(txt));
+ probs_for_token.push_back(json{
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.txt)},
+ {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
+ });
+ }
+ return probs_for_token;
+ }
+
+ static json probs_vector_to_json(const std::vector<completion_token_output> &probs, bool post_sampling_probs) {
+ json out = json::array();
+ for (const auto &p : probs) {
+ std::string txt(p.text_to_send);
+ txt.resize(validate_utf8(txt));
+ out.push_back(json{
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.text_to_send)},
+ {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
+ {post_sampling_probs ? "top_probs" : "top_logprobs", p.to_json(post_sampling_probs)},
+ });
+ }
+ return out;
+ }
+
+ static float logarithm(float x) {
+ // nlohmann::json converts -inf to null, so we need to prevent that
+ return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+ }
+
+ static std::vector<unsigned char> str_to_bytes(const std::string &str) {
+ std::vector<unsigned char> bytes;
+ for (unsigned char c : str) {
+ bytes.push_back(c);
+ }
+ return bytes;
+ }
};
-struct server_task_multi
-{
- int id = -1;
+struct server_task_result_cmpl_final : server_task_result {
+ int index = 0;
+
+ std::string content;
+ llama_tokens tokens;
+
+ bool stream;
+ result_timings timings;
+ std::string prompt;
+
+ bool truncated;
+ int32_t n_decoded;
+ int32_t n_prompt_tokens;
+ int32_t n_tokens_cached;
+ bool has_new_line;
+ std::string stopping_word;
+ stop_type stop = STOP_TYPE_NONE;
- std::set<int> subtasks_remaining;
- std::vector<server_task_result> results;
+ bool post_sampling_probs;
+ std::vector<completion_token_output> probs_output;
+ std::vector<std::string> response_fields;
+
+ slot_params generation_params;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+
+ virtual int get_index() override { return index; }
+
+ virtual bool is_stop() override {
+ return true; // in stream mode, final responses are considered stop
+ }
+
+ virtual json to_json() override {
+ switch (oaicompat) {
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
+ }
+ }
+
+ json to_json_non_oaicompat() {
+ json res = json{
+ {"index", index},
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"tokens", stream ? llama_tokens{} : tokens},
+ {"id_slot", id_slot},
+ {"stop", true},
+ {"model", oaicompat_model},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
+ {"generation_settings", generation_params.to_json()},
+ {"prompt", prompt},
+ {"has_new_line", has_new_line},
+ {"truncated", truncated},
+ {"stop_type", stop_type_to_str(stop)},
+ {"stopping_word", stopping_word},
+ {"tokens_cached", n_tokens_cached},
+ {"timings", timings.to_json()},
+ };
+ if (!stream && !probs_output.empty()) {
+ res["completion_probabilities"] =
+ completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
+ }
+ return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
+ }
+
+ json to_json_oaicompat() {
+ std::time_t t = std::time(0);
+ json logprobs = json(nullptr); // OAI default to null
+ if (!stream && probs_output.size() > 0) {
+ logprobs = json{
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+ };
+ }
+ json finish_reason = "length";
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ finish_reason = "stop";
+ }
+ json res = json{
+ {"choices", json::array({json{
+ {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", finish_reason},
+ }})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"usage", json{{"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}}},
+ {"id", oaicompat_cmpl_id}};
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+
+ return res;
+ }
+
+ json to_json_oaicompat_chat() {
+ std::string finish_reason = "length";
+ common_chat_msg msg;
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ SRV_DBG("Parsing chat message: %s\n", content.c_str());
+ msg = common_chat_parse(content, oaicompat_chat_format);
+ finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+ } else {
+ msg.content = content;
+ }
+
+ json message{
+ {"role", "assistant"},
+ };
+ if (!msg.reasoning_content.empty()) {
+ message["reasoning_content"] = msg.reasoning_content;
+ }
+ if (msg.content.empty() && !msg.tool_calls.empty()) {
+ message["content"] = json();
+ } else {
+ message["content"] = msg.content;
+ }
+ if (!msg.tool_calls.empty()) {
+ auto tool_calls = json::array();
+ for (const auto &tc : msg.tool_calls) {
+ tool_calls.push_back({
+ {"type", "function"},
+ {"function",
+ {
+ {"name", tc.name},
+ {"arguments", tc.arguments},
+ }},
+ {"id", tc.id},
+ });
+ }
+ message["tool_calls"] = tool_calls;
+ }
+
+ json choice{
+ {"finish_reason", finish_reason},
+ {"index", 0},
+ {"message", message},
+ };
+
+ if (!stream && probs_output.size() > 0) {
+ choice["logprobs"] = json{
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+ };
+ }
+
+ std::time_t t = std::time(0);
+
+ json res = json{{"choices", json::array({choice})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion"},
+ {"usage", json{{"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}}},
+ {"id", oaicompat_cmpl_id}};
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+
+ return res;
+ }
+
+ json to_json_oaicompat_chat_stream() {
+ std::time_t t = std::time(0);
+ std::string finish_reason = "length";
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+ finish_reason = "stop";
+ }
+
+ json choice = json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}};
+
+ json ret = json{
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"},
+ {"usage",
+ json{
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ }},
+ };
+
+ if (timings.prompt_n >= 0) {
+ ret.push_back({"timings", timings.to_json()});
+ }
+
+ return ret;
+ }
};
-struct slot_params
-{
- bool stream = true;
- bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+struct server_task_result_cmpl_partial : server_task_result {
+ int index = 0;
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard =
- 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
- int32_t n_predict = -1; // new tokens to predict
+ std::string content;
+ llama_tokens tokens;
- std::vector<std::string> antiprompt;
+ int32_t n_decoded;
+ int32_t n_prompt_tokens;
+
+ bool post_sampling_probs;
+ completion_token_output prob_output;
+ result_timings timings;
+
+ // OAI-compat fields
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+
+ virtual int get_index() override { return index; }
+
+ virtual bool is_stop() override {
+ return false; // in stream mode, partial responses are not considered stop
+ }
+
+ virtual json to_json() override {
+ switch (oaicompat) {
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
+ }
+ }
+
+ json to_json_non_oaicompat() {
+ // non-OAI-compat JSON
+ json res = json{
+ {"index", index},
+ {"content", content},
+ {"tokens", tokens},
+ {"stop", false},
+ {"id_slot", id_slot},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
+ };
+ // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
+ if (timings.prompt_n > 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+ if (!prob_output.probs.empty()) {
+ res["completion_probabilities"] =
+ completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
+ }
+ return res;
+ }
- json input_prefix;
- json input_suffix;
+ json to_json_oaicompat() {
+ std::time_t t = std::time(0);
+ json logprobs = json(nullptr); // OAI default to null
+ if (prob_output.probs.size() > 0) {
+ logprobs = json{
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+ };
+ }
+ json res = json{{"choices", json::array({json{
+ {"text", content},
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", nullptr},
+ }})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"id", oaicompat_cmpl_id}};
+
+ // extra fields for debugging purposes
+ if (verbose) {
+ res["__verbose"] = to_json_non_oaicompat();
+ }
+ if (timings.prompt_n >= 0) {
+ res.push_back({"timings", timings.to_json()});
+ }
+
+ return res;
+ }
+
+ json to_json_oaicompat_chat() {
+ bool first = n_decoded == 0;
+ std::time_t t = std::time(0);
+ json choices;
+
+ if (first) {
+ if (content.empty()) {
+ choices = json::array(
+ {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}});
+ } else {
+ // We have to send this as two updates to conform to openai behavior
+ json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"role", "assistant"}}}}})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
+
+ json second_ret =
+ json{{"choices",
+ json::array(
+ {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
+
+ return std::vector<json>({initial_ret, second_ret});
+ }
+ } else {
+ choices = json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta",
+ json{
+ {"content", content},
+ }},
+ }});
+ }
+
+ GGML_ASSERT(choices.size() >= 1);
+
+ if (prob_output.probs.size() > 0) {
+ choices[0]["logprobs"] = json{
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+ };
+ }
+
+ json ret = json{{"choices", choices},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"}};
+
+ if (timings.prompt_n >= 0) {
+ ret.push_back({"timings", timings.to_json()});
+ }
+
+ return std::vector<json>({ret});
+ }
};
-struct server_slot
-{
+struct server_task_result_embd : server_task_result {
+ int index = 0;
+ std::vector<std::vector<float>> embedding;
+
+ int32_t n_tokens;
+
+ // OAI-compat fields
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+
+ virtual int get_index() override { return index; }
+
+ virtual json to_json() override {
+ return oaicompat == OAICOMPAT_TYPE_EMBEDDING ? to_json_oaicompat() : to_json_non_oaicompat();
+ }
+
+ json to_json_non_oaicompat() {
+ return json{
+ {"index", index},
+ {"embedding", embedding},
+ };
+ }
+
+ json to_json_oaicompat() {
+ return json{
+ {"index", index},
+ {"embedding", embedding[0]},
+ {"tokens_evaluated", n_tokens},
+ };
+ }
+};
+
+struct server_task_result_rerank : server_task_result {
+ int index = 0;
+ float score = -1e6;
+
+ int32_t n_tokens;
+
+ virtual int get_index() override { return index; }
+
+ virtual json to_json() override {
+ return json{
+ {"index", index},
+ {"score", score},
+ {"tokens_evaluated", n_tokens},
+ };
+ }
+};
+
+// this function may be used outside of server_task_result_error
+static json format_error_response(const std::string &message, const enum error_type type) {
+ std::string type_str;
+ int code = 500;
+ switch (type) {
+ case ERROR_TYPE_INVALID_REQUEST:
+ type_str = "invalid_request_error";
+ code = 400;
+ break;
+ case ERROR_TYPE_AUTHENTICATION:
+ type_str = "authentication_error";
+ code = 401;
+ break;
+ case ERROR_TYPE_NOT_FOUND:
+ type_str = "not_found_error";
+ code = 404;
+ break;
+ case ERROR_TYPE_SERVER:
+ type_str = "server_error";
+ code = 500;
+ break;
+ case ERROR_TYPE_PERMISSION:
+ type_str = "permission_error";
+ code = 403;
+ break;
+ case ERROR_TYPE_NOT_SUPPORTED:
+ type_str = "not_supported_error";
+ code = 501;
+ break;
+ case ERROR_TYPE_UNAVAILABLE:
+ type_str = "unavailable_error";
+ code = 503;
+ break;
+ }
+ return json{
+ {"code", code},
+ {"message", message},
+ {"type", type_str},
+ };
+}
+
+struct server_task_result_error : server_task_result {
+ int index = 0;
+ error_type err_type = ERROR_TYPE_SERVER;
+ std::string err_msg;
+
+ virtual bool is_error() override { return true; }
+
+ virtual json to_json() override { return format_error_response(err_msg, err_type); }
+};
+
+struct server_task_result_metrics : server_task_result {
+ int n_idle_slots;
+ int n_processing_slots;
+ int n_tasks_deferred;
+ int64_t t_start;
+
+ int32_t kv_cache_tokens_count;
+ int32_t kv_cache_used_cells;
+
+ // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
+ uint64_t n_prompt_tokens_processed_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
+
+ uint64_t n_prompt_tokens_processed = 0;
+ uint64_t t_prompt_processing = 0;
+
+ uint64_t n_tokens_predicted = 0;
+ uint64_t t_tokens_generation = 0;
+
+ uint64_t n_decode_total = 0;
+ uint64_t n_busy_slots_total = 0;
+
+ // while we could also use std::vector<server_slot>, this requires copying the slot object, which can be quite messy
+ // therefore, we use json to temporarily store the slot.to_json() result
+ json slots_data = json::array();
+
+ virtual json to_json() override {
+ return json{
+ {"idle", n_idle_slots},
+ {"processing", n_processing_slots},
+ {"deferred", n_tasks_deferred},
+ {"t_start", t_start},
+
+ {"n_prompt_tokens_processed_total", n_prompt_tokens_processed_total},
+ {"t_tokens_generation_total", t_tokens_generation_total},
+ {"n_tokens_predicted_total", n_tokens_predicted_total},
+ {"t_prompt_processing_total", t_prompt_processing_total},
+
+ {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+ {"t_prompt_processing", t_prompt_processing},
+ {"n_tokens_predicted", n_tokens_predicted},
+ {"t_tokens_generation", t_tokens_generation},
+
+ {"n_decode_total", n_decode_total},
+ {"n_busy_slots_total", n_busy_slots_total},
+
+ {"kv_cache_tokens_count", kv_cache_tokens_count},
+ {"kv_cache_used_cells", kv_cache_used_cells},
+
+ {"slots", slots_data},
+ };
+ }
+};
+
+struct server_task_result_slot_save_load : server_task_result {
+ std::string filename;
+ bool is_save; // true = save, false = load
+
+ size_t n_tokens;
+ size_t n_bytes;
+ double t_ms;
+
+ virtual json to_json() override {
+ if (is_save) {
+ return json{
+ {"id_slot", id_slot}, {"filename", filename}, {"n_saved", n_tokens},
+ {"n_written", n_bytes}, {"timings", {{"save_ms", t_ms}}},
+ };
+ } else {
+ return json{
+ {"id_slot", id_slot},
+ {"filename", filename},
+ {"n_restored", n_tokens},
+ {"n_read", n_bytes},
+ {"timings", {{"restore_ms", t_ms}}},
+ };
+ }
+ }
+};
+
+struct server_task_result_slot_erase : server_task_result {
+ size_t n_erased;
+
+ virtual json to_json() override {
+ return json{
+ {"id_slot", id_slot},
+ {"n_erased", n_erased},
+ };
+ }
+};
+
+struct server_task_result_apply_lora : server_task_result {
+ virtual json to_json() override { return json{{"success", true}}; }
+};
+
+struct server_slot {
int id;
int id_task = -1;
- int id_multi = -1;
+
+ // only used for completion/embedding/infill/rerank
+ server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
+ llama_batch batch_spec = {};
+
+ llama_context *ctx = nullptr;
+ llama_context *ctx_dft = nullptr;
+
+ common_speculative *spec = nullptr;
+
+ std::vector lora;
+
+ // the index relative to completion multi-task request
+ size_t index = 0;
struct slot_params params;
slot_state state = SLOT_STATE_IDLE;
- slot_command command = SLOT_COMMAND_NONE;
// used to determine the slot that has been used the longest
int64_t t_last_used = -1;
@@ -125,46 +1193,40 @@ struct server_slot
int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
+ // n_prompt_tokens may not be equal to prompt_tokens.size(), because the prompt may be truncated
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;
- json prompt;
+ // input prompt tokens
+ llama_tokens prompt_tokens;
- // when a task is submitted, we first tokenize the prompt and store it here
- std::vector<llama_token> prompt_tokens;
+ size_t last_nl_pos = 0;
std::string generated_text;
- std::vector<llama_token> cache_tokens;
+ llama_tokens generated_tokens;
+
+ llama_tokens cache_tokens;
+
 std::vector<completion_token_output> generated_token_probs;
- bool infill = false;
- bool embedding = false;
bool has_next_token = true;
+ bool has_new_line = false;
bool truncated = false;
- bool stopped_eos = false;
- bool stopped_word = false;
- bool stopped_limit = false;
-
- bool oaicompat = false;
+ stop_type stop;
- std::string oaicompat_model;
std::string stopping_word;
// sampling
- llama_token sampled;
- struct llama_sampling_params sparams;
- llama_sampling_context *ctx_sampling = nullptr;
json json_schema;
- int32_t ga_i = 0; // group-attention state
- int32_t ga_n = 1; // group-attention factor
- int32_t ga_w = 512; // group-attention width
+ struct common_sampler *smpl = nullptr;
- int32_t n_past_se = 0; // self-extend
+ llama_token sampled;
+
+ common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
// stats
 size_t n_sent_text = 0; // number of sent text characters
- size_t n_sent_token_probs = 0;
int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -172,114 +1234,107 @@ struct server_slot
double t_prompt_processing; // ms
double t_token_generation; // ms
- void reset()
- {
+ std::function<void(int)> callback_on_release;
+
+ void reset() {
+ SLT_DBG(*this, "%s", "\n");
+
n_prompt_tokens = 0;
+ last_nl_pos = 0;
generated_text = "";
+ has_new_line = false;
truncated = false;
- stopped_eos = false;
- stopped_word = false;
- stopped_limit = false;
+ stop = STOP_TYPE_NONE;
stopping_word = "";
n_past = 0;
n_sent_text = 0;
- n_sent_token_probs = 0;
- infill = false;
- ga_i = 0;
- n_past_se = 0;
+ task_type = SERVER_TASK_TYPE_COMPLETION;
+ generated_tokens.clear();
generated_token_probs.clear();
}
- bool has_budget(gpt_params &global_params)
- {
- if (params.n_predict == -1 && global_params.n_predict == -1)
- {
+ bool is_non_causal() const {
+ return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+ }
+
+ bool can_batch_with(server_slot &other_slot) {
+ return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora);
+ }
+
+ bool has_budget(const common_params &global_params) {
+ if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
n_remaining = -1;
- if (params.n_predict != -1)
- {
+ if (params.n_predict != -1) {
n_remaining = params.n_predict - n_decoded;
- }
- else if (global_params.n_predict != -1)
- {
+ } else if (global_params.n_predict != -1) {
n_remaining = global_params.n_predict - n_decoded;
}
return n_remaining > 0; // no budget
}
- bool available() const
- {
- return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
- }
+ bool is_processing() const { return state != SLOT_STATE_IDLE; }
- bool is_processing() const
- {
- return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
- }
+ bool can_speculate() const { return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; }
- void add_token_string(const completion_token_output &token)
- {
- if (command == SLOT_COMMAND_RELEASE)
- {
+ void add_token(const completion_token_output &token) {
+ if (!is_processing()) {
+ SLT_WRN(*this, "%s", "slot is not processing\n");
return;
}
generated_token_probs.push_back(token);
}
- void release()
- {
- if (state == SLOT_STATE_PROCESSING)
- {
+ void release() {
+ if (is_processing()) {
+ SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+
+ t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
- command = SLOT_COMMAND_RELEASE;
+ state = SLOT_STATE_IDLE;
+ callback_on_release(id);
}
}
- json get_formated_timings() const
- {
- return json{
- {"prompt_n", n_prompt_tokens_processed},
- {"prompt_ms", t_prompt_processing},
- {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
- {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},
-
- {"predicted_n", n_decoded},
- {"predicted_ms", t_token_generation},
- {"predicted_per_token_ms", t_token_generation / n_decoded},
- {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
- };
+ result_timings get_timings() const {
+ result_timings timings;
+ timings.prompt_n = n_prompt_tokens_processed;
+ timings.prompt_ms = t_prompt_processing;
+ timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
+ timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+ timings.predicted_n = n_decoded;
+ timings.predicted_ms = t_token_generation;
+ timings.predicted_per_token_ms = t_token_generation / n_decoded;
+ timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+
+ return timings;
}
- size_t find_stopping_strings(const std::string &text, const size_t last_token_size, const stop_type type)
- {
+ size_t find_stopping_strings(const std::string &text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
- for (const std::string &word : params.antiprompt)
- {
+ for (const std::string &word : params.antiprompt) {
size_t pos;
- if (type == STOP_TYPE_FULL)
- {
+ if (is_full_stop) {
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
- }
- else
- {
+ } else {
+ // otherwise, partial stop
pos = find_partial_stop_string(word, text);
}
- if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos))
- {
- if (type == STOP_TYPE_FULL)
- {
- stopped_word = true;
+ if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+ if (is_full_stop) {
+ stop = STOP_TYPE_WORD;
stopping_word = word;
has_next_token = false;
}
@@ -290,56 +1345,46 @@ struct server_slot
return stop_pos;
}
- void print_timings() const
- {
- char buffer[512];
-
- double t_token = t_prompt_processing / n_prompt_tokens_processed;
- double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
-
- snprintf(buffer, 512,
- "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
- t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_prompt_processing", t_prompt_processing},
- {"n_prompt_tokens_processed", n_prompt_tokens_processed},
- {"t_token", t_token},
- {"n_tokens_second", n_tokens_second},
- });
-
- t_token = t_token_generation / n_decoded;
- n_tokens_second = 1e3 / t_token_generation * n_decoded;
-
- snprintf(buffer, 512,
- "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
- t_token_generation, n_decoded, t_token, n_tokens_second);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_token_generation", t_token_generation},
- {"n_decoded", n_decoded},
- {"t_token", t_token},
- {"n_tokens_second", n_tokens_second},
- });
-
- snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_prompt_processing", t_prompt_processing},
- {"t_token_generation", t_token_generation},
- {"t_total", t_prompt_processing + t_token_generation},
- });
+ void print_timings() const {
+ const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
+ const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+ const double t_gen = t_token_generation / n_decoded;
+ const double n_gen_second = 1e3 / t_token_generation * n_decoded;
+
+ SLT_INF(*this,
+ "\n"
+ "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+ " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+ " total time = %10.2f ms / %5d tokens\n",
+ t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation,
+ n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation,
+ n_prompt_tokens_processed + n_decoded);
+ }
+
+ json to_json() const {
+ return json{
+ {"id", id},
+ {"id_task", id_task},
+ {"n_ctx", n_ctx},
+ {"speculative", can_speculate()},
+ {"is_processing", is_processing()},
+ {"non_causal", is_non_causal()},
+ {"params", params.to_json()},
+ {"prompt", common_detokenize(ctx, prompt_tokens)},
+ {"next_token",
+ {
+ {"has_next_token", has_next_token},
+ {"has_new_line", has_new_line},
+ {"n_remain", n_remaining},
+ {"n_decoded", n_decoded},
+ {"stopping_word", stopping_word},
+ }},
+ };
}
};
-struct server_metrics
-{
+struct server_metrics {
int64_t t_start = 0;
uint64_t n_prompt_tokens_processed_total = 0;
@@ -353,29 +1398,35 @@ struct server_metrics
uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
- void init()
- {
- t_start = ggml_time_us();
- }
+ uint64_t n_decode_total = 0;
+ uint64_t n_busy_slots_total = 0;
- void on_prompt_eval(const server_slot &slot)
- {
+ void init() { t_start = ggml_time_us(); }
+
+ void on_prompt_eval(const server_slot &slot) {
n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
t_prompt_processing += slot.t_prompt_processing;
t_prompt_processing_total += slot.t_prompt_processing;
}
- void on_prediction(const server_slot &slot)
- {
+ void on_prediction(const server_slot &slot) {
n_tokens_predicted_total += slot.n_decoded;
n_tokens_predicted += slot.n_decoded;
t_tokens_generation += slot.t_token_generation;
t_tokens_generation_total += slot.t_token_generation;
}
- void reset_bucket()
- {
+ void on_decoded(const std::vector<server_slot> &slots) {
+ n_decode_total++;
+ for (const auto &slot : slots) {
+ if (slot.is_processing()) {
+ n_busy_slots_total++;
+ }
+ }
+ }
+
+ void reset_bucket() {
n_prompt_tokens_processed = 0;
t_prompt_processing = 0;
n_tokens_predicted = 0;
@@ -383,88 +1434,94 @@ struct server_metrics
}
};
-struct server_queue
-{
+struct server_queue {
int id = 0;
bool running;
// queues
- std::vector<server_task> queue_tasks;
- std::vector<server_task> queue_tasks_deferred;
-
- std::vector<server_task_multi> queue_multitasks;
+ std::deque<server_task> queue_tasks;
+ std::deque<server_task> queue_tasks_deferred;
std::mutex mutex_tasks;
std::condition_variable condition_tasks;
// callback functions
- std::function<void(server_task &)> callback_new_task;
- std::function<void(server_task_multi &)> callback_finish_multitask;
+ std::function<void(server_task)> callback_new_task;
 std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
- int post(server_task task)
- {
+ int post(server_task task, bool front = false) {
std::unique_lock lock(mutex_tasks);
- if (task.id == -1)
- {
- task.id = id++;
- LOG_VERBOSE("new task id", {{"new_id", task.id}});
+ GGML_ASSERT(task.id != -1);
+ // if this is a cancel task, make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ condition_tasks.notify_one();
+ return task.id;
+ }
+
+ // multi-task version of post()
+ int post(std::vector<server_task> &tasks, bool front = false) {
+ std::unique_lock lock(mutex_tasks);
+ for (auto &task : tasks) {
+ if (task.id == -1) {
+ task.id = id++;
+ }
+ // if this is a cancel task, make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
}
- queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
- return task.id;
+ return 0;
}
// Add a new task, but defer until one slot is available
- void defer(server_task task)
- {
+ void defer(server_task task) {
std::unique_lock lock(mutex_tasks);
+ QUE_DBG("defer task, id = %d\n", task.id);
queue_tasks_deferred.push_back(std::move(task));
+ condition_tasks.notify_one();
}
- // Get the next id for creating anew task
- int get_new_id()
- {
+ // Get the next id for creating a new task
+ int get_new_id() {
std::unique_lock lock(mutex_tasks);
int new_id = id++;
- LOG_VERBOSE("new task id", {{"new_id", new_id}});
return new_id;
}
// Register function to process a new task
- void on_new_task(std::function<void(server_task &)> callback)
- {
- callback_new_task = std::move(callback);
- }
-
- // Register function to process a multitask when it is finished
- void on_finish_multitask(std::function<void(server_task_multi &)> callback)
- {
- callback_finish_multitask = std::move(callback);
- }
+ void on_new_task(std::function<void(server_task)> callback) { callback_new_task = std::move(callback); }
// Register the function to be called when all slots data is ready to be processed
- void on_update_slots(std::function<void(void)> callback)
- {
- callback_update_slots = std::move(callback);
- }
+ void on_update_slots(std::function<void(void)> callback) { callback_update_slots = std::move(callback); }
- // Call when the state of one slot is changed
- void notify_slot_changed()
- {
- // move deferred tasks back to main loop
+ // Call when the state of one slot is changed; it will move one task from the deferred queue to the main queue
+ void pop_deferred_task() {
std::unique_lock lock(mutex_tasks);
- for (auto &task : queue_tasks_deferred)
- {
- queue_tasks.push_back(std::move(task));
+ if (!queue_tasks_deferred.empty()) {
+ queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
+ queue_tasks_deferred.pop_front();
}
- queue_tasks_deferred.clear();
+ condition_tasks.notify_one();
}
// end the start_loop routine
- void terminate()
- {
+ void terminate() {
std::unique_lock lock(mutex_tasks);
running = false;
condition_tasks.notify_all();
@@ -477,146 +1534,120 @@ struct server_queue
* - Check if multitask is finished
* - Update all slots
*/
- void start_loop()
- {
+ void start_loop() {
running = true;
- while (true)
- {
- LOG_VERBOSE("new task may arrive", {});
+ while (true) {
+ QUE_DBG("%s", "processing new tasks\n");
- while (true)
- {
+ while (true) {
std::unique_lock lock(mutex_tasks);
- if (queue_tasks.empty())
- {
+ if (!running) {
+ QUE_DBG("%s", "terminate\n");
+ return;
+ }
+ if (queue_tasks.empty()) {
lock.unlock();
break;
}
server_task task = queue_tasks.front();
- queue_tasks.erase(queue_tasks.begin());
+ queue_tasks.pop_front();
lock.unlock();
- LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
- callback_new_task(task);
- }
-
- LOG_VERBOSE("update_multitasks", {});
- // check if we have any finished multitasks
- auto queue_iterator = queue_multitasks.begin();
- while (queue_iterator != queue_multitasks.end())
- {
- if (queue_iterator->subtasks_remaining.empty())
- {
- // all subtasks done == multitask is done
- server_task_multi current_multitask = *queue_iterator;
- callback_finish_multitask(current_multitask);
- // remove this multitask
- queue_iterator = queue_multitasks.erase(queue_iterator);
- }
- else
- {
- ++queue_iterator;
- }
+ QUE_DBG("processing task, id = %d\n", task.id);
+ callback_new_task(std::move(task));
}
 // all tasks in the current loop are processed, slots data is now ready
- LOG_VERBOSE("callback_update_slots", {});
+ QUE_DBG("%s", "update slots\n");
callback_update_slots();
- LOG_VERBOSE("wait for new task", {});
+ QUE_DBG("%s", "waiting for new tasks\n");
{
std::unique_lock lock(mutex_tasks);
- if (queue_tasks.empty())
- {
- if (!running)
- {
- LOG_VERBOSE("ending start_loop", {});
- return;
- }
+ if (!running) {
+ QUE_DBG("%s", "terminate\n");
+ return;
+ }
+ if (queue_tasks.empty()) {
condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); });
}
}
}
}
- //
- // functions to manage multitasks
- //
-
- // add a multitask by specifying the id of all subtask (subtask is a server_task)
- void add_multitask(int id_multi, std::vector &sub_ids)
- {
- std::lock_guard lock(mutex_tasks);
- server_task_multi multi;
- multi.id = id_multi;
- std::copy(sub_ids.begin(), sub_ids.end(),
- std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
- queue_multitasks.push_back(multi);
- }
-
- // updatethe remaining subtasks, while appending results to multitask
- void update_multitask(int id_multi, int id_sub, server_task_result &result)
- {
- std::lock_guard lock(mutex_tasks);
- for (auto &multitask : queue_multitasks)
- {
- if (multitask.id == id_multi)
- {
- multitask.subtasks_remaining.erase(id_sub);
- multitask.results.push_back(result);
- }
- }
+ private:
+ void cleanup_pending_task(int id_target) {
+ // no need to lock because this is called exclusively by post()
+ auto rm_func = [id_target](const server_task &task) { return task.id_target == id_target; };
+ queue_tasks.erase(std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), queue_tasks.end());
+ queue_tasks_deferred.erase(std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
}
};
-struct server_response
-{
- typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
- callback_multitask_t callback_update_multitask;
-
+struct server_response {
// for keeping track of all tasks waiting for the result
- std::set<int> waiting_task_ids;
+ std::unordered_set<int> waiting_task_ids;
- // the main result queue
- std::vector<server_task_result> queue_results;
+ // the main result queue (using ptr for polymorphism)
+ std::vector<server_task_result_ptr> queue_results;
std::mutex mutex_results;
std::condition_variable condition_results;
// add the id_task to the list of tasks waiting for response
- void add_waiting_task_id(int id_task)
- {
- LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
+ void add_waiting_task_id(int id_task) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task,
+ (int)waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.insert(id_task);
}
+ void add_waiting_tasks(const std::vector<server_task> &tasks) {
+ std::unique_lock lock(mutex_results);
+
+ for (const auto &task : tasks) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id,
+ (int)waiting_task_ids.size());
+ waiting_task_ids.insert(task.id);
+ }
+ }
+
 // when the request is finished, we can remove the task associated with it
- void remove_waiting_task_id(int id_task)
- {
- LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
+ void remove_waiting_task_id(int id_task) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
+ (int)waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.erase(id_task);
+ // make sure to clean up all pending results
+ queue_results.erase(std::remove_if(queue_results.begin(), queue_results.end(),
+ [id_task](const server_task_result_ptr &res) { return res->id == id_task; }),
+ queue_results.end());
}
- // This function blocks the thread until there is a response for this id_task
- server_task_result recv(int id_task)
- {
- while (true)
- {
+ void remove_waiting_task_ids(const std::unordered_set<int> &id_tasks) {
+ std::unique_lock lock(mutex_results);
+
+ for (const auto &id_task : id_tasks) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
+ (int)waiting_task_ids.size());
+ waiting_task_ids.erase(id_task);
+ }
+ }
+
+ // This function blocks the thread until there is a response for one of the id_tasks
+ server_task_result_ptr recv(const std::unordered_set<int> &id_tasks) {
+ while (true) {
std::unique_lock lock(mutex_results);
condition_results.wait(lock, [&] { return !queue_results.empty(); });
- for (int i = 0; i < (int)queue_results.size(); i++)
- {
- if (queue_results[i].id == id_task)
- {
- assert(queue_results[i].id_multi == -1);
- server_task_result res = queue_results[i];
+ for (size_t i = 0; i < queue_results.size(); i++) {
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+ server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
@@ -626,33 +1657,45 @@ struct server_response
// should never reach here
}
- // Register the function to update multitask
- void on_multitask_update(callback_multitask_t callback)
- {
- callback_update_multitask = std::move(callback);
+ // same as recv(), but with a timeout in seconds
+ // if timeout is reached, nullptr is returned
+ server_task_result_ptr recv_with_timeout(const std::unordered_set<int> &id_tasks, int timeout) {
+ while (true) {
+ std::unique_lock lock(mutex_results);
+
+ for (int i = 0; i < (int)queue_results.size(); i++) {
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+ server_task_result_ptr res = std::move(queue_results[i]);
+ queue_results.erase(queue_results.begin() + i);
+ return res;
+ }
+ }
+
+ std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+ if (cr_res == std::cv_status::timeout) {
+ return nullptr;
+ }
+ }
+
+ // should never reach here
+ }
+
+ // single-task version of recv()
+ server_task_result_ptr recv(int id_task) {
+ std::unordered_set<int> id_tasks = {id_task};
+ return recv(id_tasks);
}
// Send a new result to a waiting id_task
- void send(server_task_result result)
- {
- LOG_VERBOSE("send new result", {{"id_task", result.id}});
+ void send(server_task_result_ptr &&result) {
+ SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock lock(mutex_results);
- for (const auto &id_task : waiting_task_ids)
- {
- // LOG_TEE("waiting task id %i \n", id_task);
- // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
- if (result.id_multi == id_task)
- {
- LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
- callback_update_multitask(id_task, result.id, result);
- continue;
- }
+ for (const auto &id_task : waiting_task_ids) {
+ if (result->id == id_task) {
+ SRV_DBG("task id = %d pushed to result queue\n", result->id);
- if (result.id == id_task)
- {
- LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
- queue_results.push_back(result);
+ queue_results.emplace_back(std::move(result));
condition_results.notify_all();
return;
}
@@ -660,26 +1703,30 @@ struct server_response
}
};
-struct server_context
-{
+struct server_context {
+ common_params params_base;
+
+ // note: keep these alive - they determine the lifetime of the model, context, etc.
+ common_init_result llama_init;
+ common_init_result llama_init_dft;
+
llama_model *model = nullptr;
llama_context *ctx = nullptr;
- gpt_params params;
+ const llama_vocab *vocab = nullptr;
- llama_batch batch;
+ llama_model *model_dft = nullptr;
+
+ llama_context_params cparams_dft;
+
+ llama_batch batch = {};
bool clean_kv_cache = true;
bool add_bos_token = true;
+ bool has_eos_token = false;
int32_t n_ctx; // total context for all clients / slots
- // system prompt
- bool system_need_update = false;
-
- std::string system_prompt;
-    std::vector<llama_token> system_tokens;
-
// slots / clients
    std::vector<server_slot> slots;
json default_generation_settings_for_props;
@@ -692,1217 +1739,851 @@ struct server_context
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
- ~server_context()
- {
- if (ctx)
- {
- llama_free(ctx);
- ctx = nullptr;
- }
-
- if (model)
- {
- llama_free_model(model);
- model = nullptr;
- }
+ common_chat_templates_ptr chat_templates;
+ ~server_context() {
// Clear any sampling context
- for (server_slot &slot : slots)
- {
- if (slot.ctx_sampling != nullptr)
- {
- llama_sampling_free(slot.ctx_sampling);
- }
- }
-
- llama_batch_free(batch);
- }
-
- bool load_model(const gpt_params ¶ms_)
- {
- params = params_;
-
- // dedicate one sequence to the system prompt
- params.n_parallel += 1;
-
- llama_init_result llama_init = llama_init_from_gpt_params(params);
-
- model = llama_init.model;
- ctx = llama_init.context;
- params.n_parallel -= 1; // but be sneaky about it
- if (model == nullptr)
- {
- LOG_ERROR("unable to load model", {{"model", params.model}});
- return false;
- }
-
- n_ctx = llama_n_ctx(ctx);
-
- add_bos_token = llama_should_add_bos_token(model);
- GGML_ASSERT(llama_add_eos_token(model) != 1);
-
- return true;
- }
-
- bool validate_model_chat_template() const
- {
- llama_chat_message chat[] = {{"user", "test"}};
-
- const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
-
- return res > 0;
- }
-
- void init()
- {
- const int32_t n_ctx_slot = n_ctx / params.n_parallel;
-
- LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
-
- for (int i = 0; i < params.n_parallel; i++)
- {
- server_slot slot;
-
- slot.id = i;
- slot.n_ctx = n_ctx_slot;
- slot.n_predict = params.n_predict;
-
- LOG_INFO("new slot", {{"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx}});
-
- const int ga_n = params.grp_attn_n;
- const int ga_w = params.grp_attn_w;
-
- if (ga_n != 1)
- {
- GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
- GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
- // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
- // GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
- LOG_INFO("slot self-extend", {{"id_slot", slot.id}, {"ga_n", ga_n}, {"ga_w", ga_w}});
- }
-
- slot.ga_i = 0;
- slot.ga_n = ga_n;
- slot.ga_w = ga_w;
-
- slot.sparams = params.sparams;
-
- slot.reset();
-
- slots.push_back(slot);
- }
+ for (server_slot &slot : slots) {
+ common_sampler_free(slot.smpl);
+ slot.smpl = nullptr;
- default_generation_settings_for_props = get_formated_generation(slots.front());
- default_generation_settings_for_props["seed"] = -1;
+ llama_free(slot.ctx_dft);
+ slot.ctx_dft = nullptr;
- // the update_slots() logic will always submit a maximum of n_batch tokens
- // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not
- // used)
- {
- const int32_t n_batch = llama_n_batch(ctx);
+ common_speculative_free(slot.spec);
+ slot.spec = nullptr;
- // only a single seq_id per token is needed
- batch = llama_batch_init(n_batch, 0, 1);
+ llama_batch_free(slot.batch_spec);
}
- metrics.init();
+ llama_batch_free(batch);
}
-    std::vector<llama_token> tokenize(const json &json_prompt, bool add_special) const
- {
- // TODO: currently, we tokenize using special tokens by default
- // this is not always correct (see
- // https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) but it's better compared to
- // completely ignoring ChatML and other chat templates
- const bool TMP_FORCE_SPECIAL = true;
-
- // If `add_bos` is true, we only add BOS, when json_prompt is a string,
- // or the first element of the json_prompt array is a string.
-        std::vector<llama_token> prompt_tokens;
-
- if (json_prompt.is_array())
- {
- bool first = true;
- for (const auto &p : json_prompt)
- {
- if (p.is_string())
- {
-                auto s = p.template get<std::string>();
-
-                std::vector<llama_token> p;
- if (first)
- {
- p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- first = false;
- }
- else
- {
- p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
- }
+ bool load_model(const common_params ¶ms) {
+ SRV_INF("loading model '%s'\n", params.model.c_str());
- prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
- }
- else
- {
- if (first)
- {
- first = false;
- }
+ params_base = params;
-                prompt_tokens.push_back(p.template get<llama_token>());
- }
- }
- }
- else
- {
-        auto s = json_prompt.template get<std::string>();
- prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- }
+ llama_init = common_init_from_params(params_base);
- return prompt_tokens;
- }
+ model = llama_init.model.get();
+ ctx = llama_init.context.get();
- server_slot *get_slot_by_id(int id)
- {
- for (server_slot &slot : slots)
- {
- if (slot.id == id)
- {
- return &slot;
- }
+ if (model == nullptr) {
+ SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
+ return false;
}
- return nullptr;
- }
-
- server_slot *get_available_slot(const std::string &prompt)
- {
- server_slot *ret = nullptr;
-
- // find the slot that has at least n% prompt similarity
- if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty())
- {
- int max_lcp_len = 0;
- float similarity = 0;
-
- for (server_slot &slot : slots)
- {
- // skip the slot if it is not available
- if (!slot.available())
- {
- continue;
- }
-
- // skip the slot if it does not contains prompt
- if (!slot.prompt.is_string())
- {
- continue;
- }
-
- // current slot's prompt
-            std::string slot_prompt = slot.prompt.get<std::string>();
-
- // length of the current slot's prompt
- int slot_prompt_len = slot_prompt.size();
-
- // length of the Longest Common Prefix between the current slot's prompt and the input prompt
- int lcp_len = common_part(slot_prompt, prompt);
-
- // fraction of the common substring length compared to the current slot's prompt length
-            similarity = static_cast<float>(lcp_len) / slot_prompt_len;
-
- // select the current slot if the criteria match
- if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity)
- {
- max_lcp_len = lcp_len;
- ret = &slot;
- }
- }
-
- if (ret != nullptr)
- {
- LOG_VERBOSE("selected slot by lcp similarity", {
- {"id_slot", ret->id},
- {"max_lcp_len", max_lcp_len},
- {"similarity", similarity},
- });
- }
- }
+ vocab = llama_model_get_vocab(model);
- // find the slot that has been least recently used
- if (ret == nullptr)
- {
- int64_t t_last = ggml_time_us();
- for (server_slot &slot : slots)
- {
- // skip the slot if it is not available
- if (!slot.available())
- {
- continue;
- }
+ n_ctx = llama_n_ctx(ctx);
- // select the current slot if the criteria match
- if (slot.t_last_used < t_last)
- {
- t_last = slot.t_last_used;
- ret = &slot;
- }
- }
+ add_bos_token = llama_vocab_get_add_bos(vocab);
+ has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
- if (ret != nullptr)
- {
- LOG_VERBOSE("selected slot by lru", {
- {"id_slot", ret->id},
- {"t_last", t_last},
- });
- }
- }
+ if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
+ SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
- return ret;
- }
+ auto params_dft = params_base;
- bool launch_slot_with_task(server_slot &slot, const server_task &task)
- {
- slot_params default_params;
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still
- // override them)
- llama_sampling_params default_sparams = params.sparams;
- auto &data = task.data;
-
- slot.oaicompat = false;
- slot.oaicompat_model = "";
-
- slot.params.stream = json_value(data, "stream", false);
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
- slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
- slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
- slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-
- if (slot.params.cache_prompt && slot.ga_n != 1)
- {
- LOG_WARNING("cache_prompt is not supported with group-attention", {});
- slot.params.cache_prompt = false;
- }
+ params_dft.devices = params_base.speculative.devices;
+ params_dft.hf_file = params_base.speculative.hf_file;
+ params_dft.hf_repo = params_base.speculative.hf_repo;
+ params_dft.model = params_base.speculative.model;
+ params_dft.model_url = params_base.speculative.model_url;
+ params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel
+ : params_base.speculative.n_ctx;
+ params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+ params_dft.n_parallel = 1;
- if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict)
- {
- // Might be better to reject the request with a 400 ?
- LOG_WARNING("Max tokens to predict exceeds server configuration",
- {
- {"params.n_predict", slot.params.n_predict},
- {"slot.n_predict", slot.n_predict},
- });
- slot.params.n_predict = slot.n_predict;
- }
+ llama_init_dft = common_init_from_params(params_dft);
- // infill
- slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
- slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+ model_dft = llama_init_dft.model.get();
- // get prompt
- if (!task.infill)
- {
- const auto &prompt = data.find("prompt");
- if (prompt == data.end())
- {
- send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+ if (model_dft == nullptr) {
+ SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
return false;
}
- if ((prompt->is_string()) || (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
- (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer()))
- {
- slot.prompt = *prompt;
- }
- else
- {
- send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+ if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
+ params_base.speculative.model.c_str(), params_base.model.c_str());
+
return false;
}
- }
-
- // penalize user-provided tokens
- {
- slot.sparams.penalty_prompt_tokens.clear();
- slot.sparams.use_penalty_prompt_tokens = false;
- const auto &penalty_prompt = data.find("penalty_prompt");
+ const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
- if (penalty_prompt != data.end())
- {
- if (penalty_prompt->is_string())
- {
-                const auto penalty_prompt_string = penalty_prompt->get<std::string>();
- slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
- if (slot.params.n_predict > 0)
- {
- slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() +
- slot.params.n_predict);
- }
- slot.sparams.use_penalty_prompt_tokens = true;
+ cparams_dft = common_context_params_to_llama(params_dft);
+ cparams_dft.n_batch = n_ctx_dft;
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- else if (penalty_prompt->is_array())
- {
- const auto n_tokens = penalty_prompt->size();
- slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
- const int n_vocab = llama_n_vocab(model);
- for (const auto &penalty_token : *penalty_prompt)
- {
- if (penalty_token.is_number_integer())
- {
-                    const auto tok = penalty_token.get<llama_token>();
- if (tok >= 0 && tok < n_vocab)
- {
- slot.sparams.penalty_prompt_tokens.push_back(tok);
- }
- }
- }
- slot.sparams.use_penalty_prompt_tokens = true;
+ // force F16 KV cache for the draft model for extra performance
+ cparams_dft.type_k = GGML_TYPE_F16;
+ cparams_dft.type_v = GGML_TYPE_F16;
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- }
+ // the context is not needed - we will create one for each slot
+ llama_init_dft.context.reset();
}
- {
- slot.sparams.logit_bias.clear();
-
- if (json_value(data, "ignore_eos", false))
- {
- slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
- }
-
- const auto &logit_bias = data.find("logit_bias");
- if (logit_bias != data.end() && logit_bias->is_array())
- {
- const int n_vocab = llama_n_vocab(model);
- for (const auto &el : *logit_bias)
- {
- // TODO: we may want to throw errors here, in case "el" is incorrect
- if (el.is_array() && el.size() == 2)
- {
- float bias;
- if (el[1].is_number())
- {
-                        bias = el[1].get<float>();
- }
-                    else if (el[1].is_boolean() && !el[1].get<bool>())
- {
- bias = -INFINITY;
- }
- else
- {
- continue;
- }
-
- if (el[0].is_number_integer())
- {
-                        llama_token tok = el[0].get<llama_token>();
- if (tok >= 0 && tok < n_vocab)
- {
- slot.sparams.logit_bias[tok] = bias;
- }
- }
- else if (el[0].is_string())
- {
-                        auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
- for (auto tok : toks)
- {
- slot.sparams.logit_bias[tok] = bias;
- }
- }
- }
- }
- }
+ chat_templates = common_chat_templates_init(model, params_base.chat_template);
+ try {
+ common_chat_format_example(chat_templates.get(), params.use_jinja);
+ } catch (const std::exception &e) {
+ SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. "
+ "This may cause the model to output suboptimal responses\n",
+ __func__);
+ chat_templates = common_chat_templates_init(model, "chatml");
}
- {
- slot.params.antiprompt.clear();
+ return true;
+ }
- const auto &stop = data.find("stop");
- if (stop != data.end() && stop->is_array())
- {
- for (const auto &word : *stop)
- {
- if (!word.empty())
- {
- slot.params.antiprompt.push_back(word);
- }
+ void init() {
+ const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
+
+ SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
+
+ for (int i = 0; i < params_base.n_parallel; i++) {
+ server_slot slot;
+
+ slot.id = i;
+ slot.ctx = ctx;
+ slot.n_ctx = n_ctx_slot;
+ slot.n_predict = params_base.n_predict;
+
+ if (model_dft) {
+ slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+
+ slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
+ if (slot.ctx_dft == nullptr) {
+ SRV_ERR("%s", "failed to create draft context\n");
+ return;
}
- }
- }
- {
- const auto &samplers_sequence = data.find("samplers");
- if (samplers_sequence != data.end() && samplers_sequence->is_array())
- {
-            std::vector<std::string> sampler_names;
- for (const auto &sampler_name : *samplers_sequence)
- {
- if (sampler_name.is_string())
- {
- sampler_names.emplace_back(sampler_name);
- }
+ slot.spec = common_speculative_init(slot.ctx_dft);
+ if (slot.spec == nullptr) {
+ SRV_ERR("%s", "failed to create speculator\n");
+ return;
}
- slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
- }
- else
- {
- slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
}
+
+ SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+
+ slot.params.sampling = params_base.sampling;
+
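+            // when a slot is released, let the task queue schedule a deferred task (if any)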
+ slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); };
+
+ slot.reset();
+
+ slots.push_back(slot);
}
+ default_generation_settings_for_props = slots[0].to_json();
+
+ // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
+ // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not
+ // used)
{
- if (slot.ctx_sampling != nullptr)
- {
- llama_sampling_free(slot.ctx_sampling);
- }
- slot.ctx_sampling = llama_sampling_init(slot.sparams);
- if (slot.ctx_sampling == nullptr)
- {
- // for now, the only error that may happen here is invalid grammar
- send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
+ const int32_t n_batch = llama_n_batch(ctx);
+
+ // only a single seq_id per token is needed
+ batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
- slot.command = SLOT_COMMAND_LOAD_PROMPT;
- slot.prompt_tokens.clear();
+ metrics.init();
+ }
- LOG_INFO("slot is processing task", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- });
+ server_slot *get_slot_by_id(int id) {
+ for (server_slot &slot : slots) {
+ if (slot.id == id) {
+ return &slot;
+ }
+ }
- return true;
+ return nullptr;
}
- void kv_cache_clear()
- {
- LOG_VERBOSE("clearing KV cache", {});
+ server_slot *get_available_slot(const server_task &task) {
+ server_slot *ret = nullptr;
- // clear the entire KV cache
- llama_kv_cache_clear(ctx);
- clean_kv_cache = false;
- }
+ // find the slot that has at least n% prompt similarity
+ if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+ int lcs_len = 0;
+ float similarity = 0;
+
+ for (server_slot &slot : slots) {
+ // skip the slot if it is not available
+ if (slot.is_processing()) {
+ continue;
+ }
- void system_prompt_update()
- {
- LOG_VERBOSE("system prompt update", {
- {"system_prompt", system_prompt},
- });
+            // skip the slot if it does not contain cached tokens
+ if (slot.cache_tokens.empty()) {
+ continue;
+ }
- kv_cache_clear();
- system_tokens.clear();
+ // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+ int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
- if (!system_prompt.empty())
- {
- system_tokens = ::llama_tokenize(ctx, system_prompt, true);
+ // fraction of the common subsequence length compared to the current slot's prompt length
+                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<float>(slot.cache_tokens.size());
- llama_batch_clear(batch);
+ // select the current slot if the criteria match
+ if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
+ lcs_len = cur_lcs_len;
+ similarity = cur_similarity;
+ ret = &slot;
+ }
+ }
- for (int i = 0; i < (int)system_tokens.size(); ++i)
- {
- llama_batch_add(batch, system_tokens[i], i, {0}, false);
+ if (ret != nullptr) {
+ SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
}
+ }
- const int32_t n_batch = llama_n_batch(ctx);
+ // find the slot that has been least recently used
+ if (ret == nullptr) {
+ int64_t t_last = ggml_time_us();
+ for (server_slot &slot : slots) {
+ // skip the slot if it is not available
+ if (slot.is_processing()) {
+ continue;
+ }
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch)
- {
- const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
- llama_batch batch_view = {
- n_tokens,
- batch.token + i,
- nullptr,
- batch.pos + i,
- batch.n_seq_id + i,
- batch.seq_id + i,
- batch.logits + i,
- 0,
- 0,
- 0, // unused
- };
-
- if (llama_decode(ctx, batch_view) != 0)
- {
- LOG_ERROR("llama_decode() failed", {});
- return;
+ // select the current slot if the criteria match
+ if (slot.t_last_used < t_last) {
+ t_last = slot.t_last_used;
+ ret = &slot;
}
}
- // assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i <= params.n_parallel; ++i)
- {
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+ if (ret != nullptr) {
+ SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
}
}
- system_need_update = false;
+ return ret;
}
- bool system_prompt_set(const std::string &sys_prompt)
- {
- system_prompt = sys_prompt;
+ bool launch_slot_with_task(server_slot &slot, const server_task &task) {
+ slot.reset();
+ slot.id_task = task.id;
+ slot.index = task.index;
+ slot.task_type = task.type;
+ slot.params = std::move(task.params);
+ slot.prompt_tokens = std::move(task.prompt_tokens);
+
+ if (!are_lora_equal(task.params.lora, slot.lora)) {
+ // if lora is changed, we cannot reuse cached tokens
+ slot.cache_tokens.clear();
+ slot.lora = task.params.lora;
+ }
+
+ SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
+
+ if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
+ // Might be better to reject the request with a 400 ?
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict,
+ slot.n_predict);
+ slot.params.n_predict = slot.n_predict;
+ }
- LOG_VERBOSE("system prompt process", {
- {"system_prompt", system_prompt},
- });
+ if (slot.params.ignore_eos && has_eos_token) {
+ slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
+ }
- // release all slots
- for (server_slot &slot : slots)
{
- slot.release();
+ if (slot.smpl != nullptr) {
+ common_sampler_free(slot.smpl);
+ }
+
+ slot.smpl = common_sampler_init(model, slot.params.sampling);
+ if (slot.smpl == nullptr) {
+ // for now, the only error that may happen here is invalid grammar
+ send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
}
- system_need_update = true;
+ if (slot.ctx_dft) {
+ llama_batch_free(slot.batch_spec);
+
+ slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+ }
+
+ slot.state = SLOT_STATE_STARTED;
+
+ SLT_INF(slot, "%s", "processing task\n");
+
return true;
}
- bool process_token(completion_token_output &result, server_slot &slot)
- {
+ void kv_cache_clear() {
+ SRV_DBG("%s", "clearing KV cache\n");
+
+ // clear the entire KV cache
+ llama_kv_cache_clear(ctx);
+ clean_kv_cache = false;
+ }
+
+ bool process_token(completion_token_output &result, server_slot &slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
+ const std::string token_str = result.text_to_send;
slot.sampled = result.tok;
- // search stop word and delete it
slot.generated_text += token_str;
- slot.has_next_token = true;
-
- if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
- {
- // we can change penalty_prompt_tokens because it is always created from scratch each request
- slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
+ if (slot.params.return_tokens) {
+ slot.generated_tokens.push_back(result.tok);
}
+ slot.has_next_token = true;
// check if there is incomplete UTF-8 character at the end
- bool incomplete = false;
- for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
- {
- unsigned char c = slot.generated_text[slot.generated_text.size() - i];
- if ((c & 0xC0) == 0x80)
- {
- // continuation byte: 10xxxxxx
- continue;
- }
- if ((c & 0xE0) == 0xC0)
- {
- // 2-byte character: 110xxxxx ...
- incomplete = i < 2;
- }
- else if ((c & 0xF0) == 0xE0)
- {
- // 3-byte character: 1110xxxx ...
- incomplete = i < 3;
- }
- else if ((c & 0xF8) == 0xF0)
- {
- // 4-byte character: 11110xxx ...
- incomplete = i < 4;
- }
- // else 1-byte character or invalid byte
- break;
- }
+ bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
- if (!incomplete)
- {
+ // search stop word and delete it
+ if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
- bool is_stop_full = false;
+ bool send_text = true;
- size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
- if (stop_pos != std::string::npos)
- {
- is_stop_full = true;
+ size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
+ if (stop_pos != std::string::npos) {
slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
- }
- else
- {
- is_stop_full = false;
- stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
+ } else if (slot.has_next_token) {
+ stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
+ send_text = stop_pos == std::string::npos;
}
// check if there is any token to predict
- if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0))
- {
+ if (send_text) {
// no send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();
// add the token to slot queue and cache
+ } else {
+ result.text_to_send = "";
}
- slot.add_token_string(result);
- if (slot.params.stream)
- {
+ slot.add_token(result);
+ if (slot.params.stream) {
send_partial_response(slot, result);
}
}
- if (incomplete)
- {
+ if (incomplete) {
slot.has_next_token = true;
}
// check the limits
- if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params))
- {
- slot.stopped_limit = true;
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- LOG_VERBOSE("stopped by limit", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_decoded", slot.n_decoded},
- {"n_predict", slot.params.n_predict},
- });
+ SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
- if (llama_token_is_eog(model, result.tok))
- {
- slot.stopped_eos = true;
+ if (slot.has_new_line) {
+ // if we have already seen a new line, we stop after a certain time limit
+ if (slot.params.t_max_predict_ms > 0 &&
+ (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) {
+ slot.stop = STOP_TYPE_LIMIT;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded,
+ (int)slot.params.t_max_predict_ms);
+ }
+
+ // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+ if (slot.params.n_indent > 0) {
+ // check the current indentation
+ // TODO: improve by not doing it more than once for each new line
+ if (slot.last_nl_pos > 0) {
+ size_t pos = slot.last_nl_pos;
+
+ int n_indent = 0;
+ while (pos < slot.generated_text.size() &&
+ (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+ n_indent++;
+ pos++;
+ }
+
+ if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+ slot.stop = STOP_TYPE_LIMIT;
+ slot.has_next_token = false;
+
+ // cut the last line
+ slot.generated_text.erase(pos, std::string::npos);
+
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded,
+ n_indent);
+ }
+ }
+
+ // find the next new line
+ {
+ const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+ if (pos != std::string::npos) {
+ slot.last_nl_pos = pos + 1;
+ }
+ }
+ }
+ }
+
+ // check if there is a new line in the generated text
+ if (result.text_to_send.find('\n') != std::string::npos) {
+ slot.has_new_line = true;
+ }
+
+ // if context shift is disabled, we stop when it reaches the context limit
+ if (slot.n_past >= slot.n_ctx) {
+ slot.truncated = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- LOG_VERBOSE("eos token found", {});
+ SLT_DBG(slot,
+ "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = "
+ "%d, n_ctx = %d\n",
+ slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
}
- auto n_ctx_train = llama_n_ctx_train(model);
- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 &&
- slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train)
- {
- LOG_WARNING("n_predict is not set and self-context extend is disabled."
- " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop",
- {
- {"id_slot", slot.id},
- {"params.n_predict", slot.params.n_predict},
- {"slot.n_prompt_tokens", slot.n_prompt_tokens},
- {"slot.n_decoded", slot.n_decoded},
- {"slot.n_predict", slot.n_predict},
- {"n_slots", params.n_parallel},
- {"slot.n_ctx", slot.n_ctx},
- {"n_ctx", n_ctx},
- {"n_ctx_train", n_ctx_train},
- {"ga_n", slot.ga_n},
- });
+ if (llama_vocab_is_eog(vocab, result.tok)) {
+ slot.stop = STOP_TYPE_EOS;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "%s", "stopped by EOS\n");
+ }
+
+ const auto n_ctx_train = llama_model_n_ctx_train(model);
+
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
- slot.stopped_limit = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction
+
+ SLT_WRN(slot,
+ "n_predict (%d) is set for infinite generation. "
+ "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
+ slot.params.n_predict, n_ctx_train);
}
- LOG_VERBOSE("next token", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"token", result.tok},
- {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
- {"has_next_token", slot.has_next_token},
- {"n_remain", slot.n_remaining},
- {"n_decoded", slot.n_decoded},
- {"stopped_eos", slot.stopped_eos},
- {"stopped_word", slot.stopped_word},
- {"stopped_limit", slot.stopped_limit},
- {"stopping_word", slot.stopping_word},
- });
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining,
+ result.tok, token_str.c_str());
return slot.has_next_token; // continue
}
- json get_formated_generation(const server_slot &slot) const
- {
- const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
- const bool ignore_eos =
- eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
+ void populate_token_probs(const server_slot &slot, completion_token_output &result, bool post_sampling,
+ bool special, int idx) {
+ size_t n_probs = slot.params.sampling.n_probs;
+ size_t n_vocab = llama_vocab_n_tokens(vocab);
+ if (post_sampling) {
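+            // post-sampling probs come from the sampler's candidate list, i.e. after the sampling chain has been applied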
+ const auto *cur_p = common_sampler_get_candidates(slot.smpl);
+ const size_t max_probs = cur_p->size;
+
+ // set probability for sampled token
+ for (size_t i = 0; i < max_probs; i++) {
+ if (cur_p->data[i].id == result.tok) {
+ result.prob = cur_p->data[i].p;
+ break;
+ }
+ }
-        std::vector<std::string> samplers_sequence;
- samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
- for (const auto &sampler_type : slot.sparams.samplers_sequence)
- {
- samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
- }
-
- return json{{"n_ctx", slot.n_ctx},
- {"n_predict", slot.n_predict},
- {"model", params.model_alias},
- {"seed", slot.sparams.seed},
- {"temperature", slot.sparams.temp},
- {"dynatemp_range", slot.sparams.dynatemp_range},
- {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
- {"top_k", slot.sparams.top_k},
- {"top_p", slot.sparams.top_p},
- {"min_p", slot.sparams.min_p},
- {"tfs_z", slot.sparams.tfs_z},
- {"typical_p", slot.sparams.typical_p},
- {"repeat_last_n", slot.sparams.penalty_last_n},
- {"repeat_penalty", slot.sparams.penalty_repeat},
- {"presence_penalty", slot.sparams.penalty_present},
- {"frequency_penalty", slot.sparams.penalty_freq},
- {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
- {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
- {"mirostat", slot.sparams.mirostat},
- {"mirostat_tau", slot.sparams.mirostat_tau},
- {"mirostat_eta", slot.sparams.mirostat_eta},
- {"penalize_nl", slot.sparams.penalize_nl},
- {"stop", slot.params.antiprompt},
- {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
- {"n_keep", slot.params.n_keep},
- {"n_discard", slot.params.n_discard},
- {"ignore_eos", ignore_eos},
- {"stream", slot.params.stream},
- {"logit_bias", slot.sparams.logit_bias},
- {"n_probs", slot.sparams.n_probs},
- {"min_keep", slot.sparams.min_keep},
- {"grammar", slot.sparams.grammar},
- {"samplers", samplers_sequence}};
- }
-
- void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER)
- {
- send_error(task.id, task.id_multi, error, type);
- }
-
- void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER)
- {
- send_error(slot.id_task, slot.id_multi, error, type);
- }
-
- void send_error(const int id_task, const int id_multi, const std::string &error,
- const enum error_type type = ERROR_TYPE_SERVER)
- {
- LOG_ERROR("task error", {
- {"id_multi", id_multi},
- {"id_task", id_task},
- {"error", error},
- });
-
- server_task_result res;
- res.id = id_task;
- res.id_multi = id_multi;
- res.stop = false;
- res.error = true;
- res.data = format_error_response(error, type);
-
- queue_results.send(res);
- }
-
- void send_partial_response(server_slot &slot, completion_token_output tkn)
- {
- server_task_result res;
- res.id = slot.id_task;
- res.id_multi = slot.id_multi;
- res.error = false;
- res.stop = false;
- res.data = json{{"content", tkn.text_to_send}, {"stop", false}, {"id_slot", slot.id}, {"multimodal", false}};
-
- if (slot.sparams.n_probs > 0)
- {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
- const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
- const size_t probs_stop_pos =
- std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
+ // set probability for top n_probs tokens
+ result.probs.reserve(max_probs);
+ for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+ result.probs.push_back(
+ {cur_p->data[i].id, common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p});
+ }
+ } else {
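+            // probabilities computed from the model output for this position, before any samplers are applied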
+ // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-        std::vector<completion_token_output> probs_output;
- if (probs_pos < probs_stop_pos)
- {
- probs_output =
-                std::vector<completion_token_output>(slot.generated_token_probs.begin() + probs_pos,
- slot.generated_token_probs.begin() + probs_stop_pos);
+ // set probability for sampled token
+ for (size_t i = 0; i < n_vocab; i++) {
+ // set probability for sampled token
+ if (cur[i].id == result.tok) {
+ result.prob = cur[i].p;
+ break;
+ }
}
- slot.n_sent_token_probs = probs_stop_pos;
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+ // set probability for top n_probs tokens
+ result.probs.reserve(n_probs);
+ for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+ result.probs.push_back({cur[i].id, common_token_to_piece(ctx, cur[i].id, special), cur[i].p});
+ }
}
+ }
- if (slot.oaicompat)
- {
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
- res.data["model"] = slot.oaicompat_model;
- }
-
- queue_results.send(res);
- }
-
- void send_final_response(const server_slot &slot)
- {
- server_task_result res;
- res.id = slot.id_task;
- res.id_multi = slot.id_multi;
- res.error = false;
- res.stop = true;
- res.data = json{{"content", !slot.params.stream ? slot.generated_text : ""},
- {"id_slot", slot.id},
- {"stop", true},
- {"model", params.model_alias},
- {"tokens_predicted", slot.n_decoded},
- {"tokens_evaluated", slot.n_prompt_tokens},
- {"generation_settings", get_formated_generation(slot)},
- {"prompt", slot.prompt},
- {"truncated", slot.truncated},
- {"stopped_eos", slot.stopped_eos},
- {"stopped_word", slot.stopped_word},
- {"stopped_limit", slot.stopped_limit},
- {"stopping_word", slot.stopping_word},
- {"tokens_cached", slot.n_past},
- {"timings", slot.get_formated_timings()}};
-
- if (slot.sparams.n_probs > 0)
- {
-            std::vector<completion_token_output> probs;
- if (!slot.params.stream && slot.stopped_word)
- {
-                const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
+ void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ send_error(task.id, error, type);
+ }
- size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end() - safe_offset);
- }
- else
- {
-                probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end());
- }
+ void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ send_error(slot.id_task, error, type);
+ }
+
+ void send_error(const int id_task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
+
+        auto res = std::make_unique<server_task_result_error>();
+ res->id = id_task;
+ res->err_type = type;
+ res->err_msg = error;
+
+ queue_results.send(std::move(res));
+ }
+
+ void send_partial_response(server_slot &slot, const completion_token_output &tkn) {
+        auto res = std::make_unique<server_task_result_cmpl_partial>();
+
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->content = tkn.text_to_send;
+ res->tokens = {tkn.tok};
+
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->post_sampling_probs = slot.params.post_sampling_probs;
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+ res->verbose = slot.params.verbose;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+ // populate res.probs_output
+ if (slot.params.sampling.n_probs > 0) {
+ res->prob_output = tkn; // copy the token probs
}
- if (slot.oaicompat)
- {
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
- res.data["model"] = slot.oaicompat_model;
+ // populate timings if this is final response or timings_per_token is enabled
+ if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
+ res->timings = slot.get_timings();
+ }
+
+ queue_results.send(std::move(res));
+ }
+
+ void send_final_response(server_slot &slot) {
+        auto res = std::make_unique<server_task_result_cmpl_final>();
+ res->id = slot.id_task;
+ res->id_slot = slot.id;
+
+ res->index = slot.index;
+ res->content = std::move(slot.generated_text);
+ res->tokens = std::move(slot.generated_tokens);
+ res->timings = slot.get_timings();
+ res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+ res->response_fields = std::move(slot.params.response_fields);
+
+ res->truncated = slot.truncated;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_tokens_cached = slot.n_past;
+ res->has_new_line = slot.has_new_line;
+ res->stopping_word = slot.stopping_word;
+ res->stop = slot.stop;
+ res->post_sampling_probs = slot.params.post_sampling_probs;
+
+ res->verbose = slot.params.verbose;
+ res->stream = slot.params.stream;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+ res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
+ // populate res.probs_output
+ if (slot.params.sampling.n_probs > 0) {
+ if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
+ const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
+
+ size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
+                res->probs_output = std::vector<completion_token_output>(
+ slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset);
+ } else {
+                res->probs_output = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end());
+ }
}
- queue_results.send(res);
+ res->generation_params = slot.params; // copy the parameters
+
+ queue_results.send(std::move(res));
}
- void send_embedding(const server_slot &slot, const llama_batch &batch)
- {
- server_task_result res;
- res.id = slot.id_task;
- res.id_multi = slot.id_multi;
- res.error = false;
- res.stop = true;
+ void send_embedding(const server_slot &slot, const llama_batch &batch) {
+        auto res = std::make_unique<server_task_result_embd>();
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
+ res->oaicompat = slot.params.oaicompat;
- const int n_embd = llama_n_embd(model);
+ const int n_embd = llama_model_n_embd(model);
        std::vector<float> embd_res(n_embd, 0.0f);
- for (int i = 0; i < batch.n_tokens; ++i)
- {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1)
- {
+ for (int i = 0; i < batch.n_tokens; ++i) {
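+            // only look at outputs that belong to this slot's sequence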
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
- if (embd == NULL)
- {
+ if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
- if (embd == NULL)
- {
- LOG_ERROR("failed to get embeddings", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
-
- res.data = json{
- {"embedding", std::vector(n_embd, 0.0f)},
- };
+ if (embd == NULL) {
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
+ batch.seq_id[i][0]);
+            res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}
- llama_embd_normalize(embd, embd_res.data(), n_embd);
-
- res.data = json{
- {"embedding", embd_res},
- };
+ // normalize only when there is pooling
+ // TODO: configurable
+ if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+ common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+ res->embedding.push_back(embd_res);
+ } else {
+ res->embedding.push_back({embd, embd + n_embd});
+ }
}
- queue_results.send(res);
+ SLT_DBG(slot, "%s", "sending embeddings\n");
+
+ queue_results.send(std::move(res));
}
- void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding)
- {
- server_task task;
- task.id = id_task;
- task.id_multi = id_multi;
- task.id_target = 0;
- task.data = std::move(data);
- task.infill = infill;
- task.embedding = embedding;
- task.type = SERVER_TASK_TYPE_COMPLETION;
+ void send_rerank(const server_slot &slot, const llama_batch &batch) {
+        auto res = std::make_unique<server_task_result_rerank>();
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
- // when a completion task's prompt array is not a singleton, we split it into multiple requests
- // otherwise, it's a single-prompt task, we actually queue it
- // if there's numbers in the prompt array it will be treated as an array of tokens
- if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1)
- {
- bool numbers = false;
- for (const auto &e : task.data.at("prompt"))
- {
- if (e.is_number())
- {
- numbers = true;
- break;
- }
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+ continue;
}
- // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
- // it will completely stall the server. I don't know where the bug for this is.
- //
- // if there are numbers, it needs to be treated like a single prompt,
- // queue_tasks handles a mix of strings and numbers just fine.
- if (numbers)
- {
- queue_tasks.post(task);
+ const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ if (embd == NULL) {
+ embd = llama_get_embeddings_ith(ctx, i);
}
- else
- {
- split_multiprompt_task(id_task, task);
+
+ if (embd == NULL) {
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
+ batch.seq_id[i][0]);
+
+ res->score = -1e6;
+ continue;
}
+
+ res->score = embd[0];
}
- else
- {
- queue_tasks.post(task);
- }
- }
- void request_cancel(int id_task)
- {
- server_task task;
- task.type = SERVER_TASK_TYPE_CANCEL;
- task.id_target = id_task;
+ SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
- queue_tasks.post(task);
+ queue_results.send(std::move(res));
}
- void split_multiprompt_task(int id_multi, const server_task &multiprompt_task)
- {
- const int prompt_count = multiprompt_task.data.at("prompt").size();
- if (prompt_count <= 1)
- {
- send_error(multiprompt_task, "error while handling multiple prompts");
- return;
+ //
+ // Functions to create new task(s) and receive result(s)
+ //
+
+    void cancel_tasks(const std::unordered_set<int> &id_tasks) {
+        std::vector<server_task> cancel_tasks;
+ cancel_tasks.reserve(id_tasks.size());
+ for (const auto &id_task : id_tasks) {
+ SRV_WRN("cancel task, id_task = %d\n", id_task);
+
+ server_task task(SERVER_TASK_TYPE_CANCEL);
+ task.id_target = id_task;
+ queue_results.remove_waiting_task_id(id_task);
+ cancel_tasks.push_back(task);
}
+ // push to beginning of the queue, so it has highest priority
+ queue_tasks.post(cancel_tasks, true);
+ }
- // generate all the ID for subtask
-        std::vector<int> subtask_ids(prompt_count);
- for (int i = 0; i < prompt_count; i++)
- {
- subtask_ids[i] = queue_tasks.get_new_id();
+ // receive the results from task(s)
+    void receive_multi_results(const std::unordered_set<int> &id_tasks,
+                               const std::function<void(std::vector<server_task_result_ptr> &)> &result_handler,
+                               const std::function<void(json)> &error_handler,
+                               const std::function<bool()> &is_connection_closed) {
+        std::vector<server_task_result_ptr> results(id_tasks.size());
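+        // results are stored by task index, so the caller receives them in request order rather than arrival order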
+ for (int i = 0; i < (int)id_tasks.size(); i++) {
+ server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+ if (is_connection_closed()) {
+ cancel_tasks(id_tasks);
+ return;
+ }
+
+ if (result == nullptr) {
+ i--; // retry
+ continue;
+ }
+
+ if (result->is_error()) {
+ error_handler(result->to_json());
+ cancel_tasks(id_tasks);
+ return;
+ }
+
+            GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr ||
+                        dynamic_cast<server_task_result_embd *>(result.get()) != nullptr ||
+                        dynamic_cast<server_task_result_rerank *>(result.get()) != nullptr);
+ const size_t idx = result->get_index();
+ GGML_ASSERT(idx < results.size() && "index out of range");
+ results[idx] = std::move(result);
}
+ result_handler(results);
+ }
+
+ // receive the results from task(s), in stream mode
+    void receive_cmpl_results_stream(const std::unordered_set<int> &id_tasks,
+                                     const std::function<bool(server_task_result_ptr &)> &result_handler,
+                                     const std::function<void(json)> &error_handler,
+                                     const std::function<bool()> &is_connection_closed) {
+ size_t n_finished = 0;
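+        // stream results until every task in id_tasks has produced its final (stop) result, or the handler aborts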
+ while (true) {
+ server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+ if (is_connection_closed()) {
+ cancel_tasks(id_tasks);
+ return;
+ }
- // queue up the multitask so we can track its subtask progression
- queue_tasks.add_multitask(id_multi, subtask_ids);
+ if (result == nullptr) {
+ continue; // retry
+ }
- // add subtasks
- for (int i = 0; i < prompt_count; i++)
- {
- json subtask_data = multiprompt_task.data;
- subtask_data["prompt"] = subtask_data.at("prompt")[i];
+ if (result->is_error()) {
+ error_handler(result->to_json());
+ cancel_tasks(id_tasks);
+ return;
+ }
+
+            GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial *>(result.get()) != nullptr ||
+                        dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr);
+ if (!result_handler(result)) {
+ cancel_tasks(id_tasks);
+ break;
+ }
- // subtasks inherit everything else (infill mode, embedding mode, etc.)
- request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill,
- multiprompt_task.embedding);
+ if (result->is_stop()) {
+ if (++n_finished == id_tasks.size()) {
+ break;
+ }
+ }
}
}
- void process_single_task(const server_task &task)
- {
- switch (task.type)
- {
- case SERVER_TASK_TYPE_COMPLETION: {
- const int id_slot = json_value(task.data, "id_slot", -1);
-
- server_slot *slot;
+ //
+ // Functions to process the task
+ //
- if (id_slot != -1)
- {
- slot = get_slot_by_id(id_slot);
- }
- else
- {
- std::string prompt;
- if (task.data.contains("prompt") && task.data.at("prompt").is_string())
- {
- prompt = json_value(task.data, "prompt", std::string());
- }
+ void process_single_task(server_task task) {
+ switch (task.type) {
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK: {
+ const int id_slot = task.id_selected_slot;
- slot = get_available_slot(prompt);
- }
+ server_slot *slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
- if (slot == nullptr)
- {
+ if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
- LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
+ SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
queue_tasks.defer(task);
break;
}
- if (!slot->available())
- {
+ if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
queue_tasks.defer(task);
break;
}
- if (task.data.contains("system_prompt"))
- {
- std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
- system_prompt_set(sys_prompt);
-
- for (server_slot &slot : slots)
- {
- slot.n_past = 0;
- slot.n_past_se = 0;
- }
- }
-
- slot->reset();
-
- slot->id_task = task.id;
- slot->id_multi = task.id_multi;
- slot->infill = task.infill;
- slot->embedding = task.embedding;
-
- if (!launch_slot_with_task(*slot, task))
- {
- LOG_ERROR("error while launching slot", task.data);
+ if (!launch_slot_with_task(*slot, task)) {
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
break;
}
- }
- break;
+ } break;
case SERVER_TASK_TYPE_CANCEL: {
// release slot linked with the task id
- for (auto &slot : slots)
- {
- if (slot.id_task == task.id_target)
- {
+ for (auto &slot : slots) {
+ if (slot.id_task == task.id_target) {
slot.release();
break;
}
}
- }
- break;
+ } break;
case SERVER_TASK_TYPE_NEXT_RESPONSE: {
// do nothing
- }
- break;
+ } break;
case SERVER_TASK_TYPE_METRICS: {
json slots_data = json::array();
int n_idle_slots = 0;
int n_processing_slots = 0;
- for (server_slot &slot : slots)
- {
- json slot_data = get_formated_generation(slot);
- slot_data["id"] = slot.id;
- slot_data["id_task"] = slot.id_task;
- slot_data["state"] = slot.state;
- slot_data["prompt"] = slot.prompt;
- slot_data["next_token"] = {
- {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining},
- {"n_decoded", slot.n_decoded}, {"stopped_eos", slot.stopped_eos},
- {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit},
- {"stopping_word", slot.stopping_word},
- };
-
- if (slot_data["state"] == SLOT_STATE_IDLE)
- {
- n_idle_slots++;
- }
- else
- {
+ for (server_slot &slot : slots) {
+ json slot_data = slot.to_json();
+
+ if (slot.is_processing()) {
n_processing_slots++;
+ } else {
+ n_idle_slots++;
}
slots_data.push_back(slot_data);
}
- LOG_INFO(
- "slot data",
- {{"id_task", task.id}, {"n_idle_slots", n_idle_slots}, {"n_processing_slots", n_processing_slots}});
-
- LOG_VERBOSE("slot data", {{"id_task", task.id},
- {"n_idle_slots", n_idle_slots},
- {"n_processing_slots", n_processing_slots},
- {"slots", slots_data}});
-
- server_task_result res;
- res.id = task.id;
- res.id_multi = task.id_multi;
- res.stop = true;
- res.error = false;
- res.data = {
- {"idle", n_idle_slots},
- {"processing", n_processing_slots},
- {"deferred", queue_tasks.queue_tasks_deferred.size()},
- {"t_start", metrics.t_start},
-
- {"n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
- {"t_tokens_generation_total", metrics.t_tokens_generation_total},
- {"n_tokens_predicted_total", metrics.n_tokens_predicted_total},
- {"t_prompt_processing_total", metrics.t_prompt_processing_total},
-
- {"n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
- {"t_prompt_processing", metrics.t_prompt_processing},
- {"n_tokens_predicted", metrics.n_tokens_predicted},
- {"t_tokens_generation", metrics.t_tokens_generation},
-
- {"kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
- {"kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
-
- {"slots", slots_data},
- };
+ SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
- if (json_value(task.data, "reset_bucket", false))
- {
+            auto res = std::make_unique<server_task_result_metrics>();
+ res->id = task.id;
+ res->slots_data = std::move(slots_data);
+ res->n_idle_slots = n_idle_slots;
+ res->n_processing_slots = n_processing_slots;
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+ res->t_start = metrics.t_start;
+
+ res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
+ res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
+
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;
+
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+ res->t_prompt_processing = metrics.t_prompt_processing;
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
+ res->t_tokens_generation = metrics.t_tokens_generation;
+
+ res->n_decode_total = metrics.n_decode_total;
+ res->n_busy_slots_total = metrics.n_busy_slots_total;
+
+ if (task.metrics_reset_bucket) {
metrics.reset_bucket();
}
- queue_results.send(res);
- }
- break;
+ queue_results.send(std::move(res));
+ } break;
case SERVER_TASK_TYPE_SLOT_SAVE: {
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr)
- {
+ if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
- if (!slot->available())
- {
+ if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
queue_tasks.defer(task);
break;
}
@@ -1910,54 +2591,49 @@ struct server_context
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
- std::string filename = task.data.at("filename");
- std::string filepath = task.data.at("filepath");
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
const size_t nwrite =
- llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+ llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
- server_task_result result;
- result.id = task.id;
- result.stop = true;
- result.error = false;
- result.data = json{{"id_slot", id_slot},
- {"filename", filename},
- {"n_saved", token_count}, // tokens saved
- {"n_written", nwrite}, // bytes written
- {"timings", {{"save_ms", t_save_ms}}}};
- queue_results.send(result);
- }
- break;
+            auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = true;
+ res->n_tokens = token_count;
+ res->n_bytes = nwrite;
+ res->t_ms = t_save_ms;
+ queue_results.send(std::move(res));
+ } break;
case SERVER_TASK_TYPE_SLOT_RESTORE: {
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr)
- {
+ if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
- if (!slot->available())
- {
+ if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
queue_tasks.defer(task);
break;
}
const int64_t t_start = ggml_time_us();
- std::string filename = task.data.at("filename");
- std::string filepath = task.data.at("filepath");
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(),
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(),
slot->cache_tokens.size(), &token_count);
- if (nread == 0)
- {
+ if (nread == 0) {
slot->cache_tokens.resize(0);
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file",
ERROR_TYPE_INVALID_REQUEST);
@@ -1968,116 +2644,65 @@ struct server_context
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
- server_task_result result;
- result.id = task.id;
- result.stop = true;
- result.error = false;
- result.data = json{{"id_slot", id_slot},
- {"filename", filename},
- {"n_restored", token_count}, // tokens restored
- {"n_read", nread}, // bytes read
- {"timings", {{"restore_ms", t_restore_ms}}}};
- queue_results.send(result);
- }
- break;
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = false;
+ res->n_tokens = token_count;
+ res->n_bytes = nread;
+ res->t_ms = t_restore_ms;
+ queue_results.send(std::move(res));
+ } break;
case SERVER_TASK_TYPE_SLOT_ERASE: {
- int id_slot = task.data.at("id_slot");
+ int id_slot = task.slot_action.slot_id;
server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr)
- {
+ if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
- if (!slot->available())
- {
+ if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
queue_tasks.defer(task);
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
- server_task_result result;
- result.id = task.id;
- result.stop = true;
- result.error = false;
- result.data = json{{"id_slot", id_slot}, {"n_erased", n_erased}};
- queue_results.send(result);
- }
- break;
- }
- }
-
- void on_finish_multitask(const server_task_multi &multitask)
- {
- // all subtasks done == multitask is done
- server_task_result result;
- result.id = multitask.id;
- result.stop = true;
- result.error = false;
-
- // collect json results into one json result
- std::vector<json> result_jsons;
- for (const auto &subres : multitask.results)
- {
- result_jsons.push_back(subres.data);
- result.error = result.error && subres.error;
+ auto res = std::make_unique<server_task_result_slot_erase>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->n_erased = n_erased;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SET_LORA: {
+ params_base.lora_adapters = std::move(task.set_lora);
+ auto res = std::make_unique<server_task_result_apply_lora>();
+ res->id = task.id;
+ queue_results.send(std::move(res));
+ } break;
}
- result.data = json{{"results", result_jsons}};
-
- queue_results.send(result);
}
- void update_slots()
- {
- if (system_need_update)
- {
- system_prompt_update();
- }
-
- // release slots
- for (auto &slot : slots)
- {
- if (slot.command == SLOT_COMMAND_RELEASE)
- {
- slot.state = SLOT_STATE_IDLE;
- slot.command = SLOT_COMMAND_NONE;
- slot.t_last_used = ggml_time_us();
-
- LOG_INFO("slot released", {{"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", n_ctx},
- {"n_past", slot.n_past},
- {"n_system_tokens", system_tokens.size()},
- {"n_cache_tokens", slot.cache_tokens.size()},
- {"truncated", slot.truncated}});
-
- queue_tasks.notify_slot_changed();
- }
- }
-
+ void update_slots() {
// check if all slots are idle
{
bool all_idle = true;
- for (auto &slot : slots)
- {
- if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE)
- {
+ for (auto &slot : slots) {
+ if (slot.is_processing()) {
all_idle = false;
break;
}
}
- if (all_idle)
- {
- LOG_INFO("all slots are idle", {});
- if (system_prompt.empty() && clean_kv_cache)
- {
+ if (all_idle) {
+ SRV_INF("%s", "all slots are idle\n");
+ if (clean_kv_cache) {
kv_cache_clear();
}
@@ -2086,224 +2711,188 @@ struct server_context
}
{
- LOG_VERBOSE("posting NEXT_RESPONSE", {});
-
- server_task task;
- task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
- task.id_target = -1;
+ SRV_DBG("%s", "posting NEXT_RESPONSE\n");
+ server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+ task.id = queue_tasks.get_new_id();
queue_tasks.post(task);
}
// apply context-shift if needed
// TODO: simplify and improve
- for (server_slot &slot : slots)
- {
- if (slot.ga_n == 1)
- {
- if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1)
- {
- // Shift context
- const int n_keep = slot.params.n_keep + add_bos_token;
- const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
-
- LOG_INFO("slot context shift", {{"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_keep", n_keep},
- {"n_left", n_left},
- {"n_discard", n_discard},
- {"n_ctx", n_ctx},
- {"n_past", slot.n_past},
- {"n_system_tokens", system_tokens.size()},
- {"n_cache_tokens", slot.cache_tokens.size()}});
-
- llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past,
- -n_discard);
-
- if (slot.params.cache_prompt)
- {
- for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
- {
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
- }
+ for (server_slot &slot : slots) {
+ if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+ if (!params_base.ctx_shift) {
+ // this check is redundant (kept as a safety net)
+ // we should never get here, because generation should have already stopped in process_token()
+ slot.release();
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ continue;
+ }
- slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
- }
+ // Shift context
+ const int n_keep = slot.params.n_keep + add_bos_token;
+ const int n_left = slot.n_past - n_keep;
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
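+ // e.g. (illustrative numbers): with n_past = 4095 and n_keep = 256,
+ // n_left = 3839 and n_discard = 1919, so KV entries [256, 2175) are removed
+ // and the remaining entries are shifted left by 1919 positions below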
+
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left,
+ n_discard);
- slot.n_past -= n_discard;
+ llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
- slot.truncated = true;
+ if (slot.params.cache_prompt) {
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+ }
+
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
}
+
+ slot.n_past -= n_discard;
+
+ slot.truncated = true;
}
}
// start populating the batch for this iteration
- llama_batch_clear(batch);
+ common_batch_clear(batch);
+
+ // track if given slot can be batched with slots already in the batch
+ server_slot *slot_batched = nullptr;
+
+ auto accept_special_token = [&](server_slot &slot, llama_token token) {
+ return params_base.special ||
+ slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+ };
// first, add sampled tokens from any ongoing sequences
- for (auto &slot : slots)
- {
- if (slot.state == SLOT_STATE_IDLE)
- {
+ for (auto &slot : slots) {
+ if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
- slot.i_batch = batch.n_tokens;
+ // check if we can batch this slot with the previous one
+ if (!slot_batched) {
+ slot_batched = &slot;
+ } else if (!slot_batched->can_batch_with(slot)) {
+ continue;
+ }
- const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+ slot.i_batch = batch.n_tokens;
- // TODO: we always have to take into account the "system_tokens"
- // this is not great and needs to be improved somehow
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, {slot.id + 1}, true);
+ common_batch_add(batch, slot.sampled, slot.n_past, {slot.id}, true);
slot.n_past += 1;
- if (slot.params.cache_prompt)
- {
+ if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(slot.sampled);
}
- LOG_VERBOSE("slot decode token", {{"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", n_ctx},
- {"n_past", slot.n_past},
- {"n_system_tokens", system_tokens.size()},
- {"n_cache_tokens", slot.cache_tokens.size()},
- {"truncated", slot.truncated}});
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
+ slot.n_ctx, slot.n_past, (int)slot.cache_tokens.size(), slot.truncated);
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
- // track if this is an embedding or non-embedding batch
- // if we've added sampled tokens above, we are in non-embedding mode
- // -1: none, 0: non-embedding, 1: embedding
- int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
-
// next, batch any pending prompts without exceeding n_batch
- if (params.cont_batching || batch.n_tokens == 0)
- {
- for (auto &slot : slots)
- {
+ if (params_base.cont_batching || batch.n_tokens == 0) {
+ for (auto &slot : slots) {
+ // check if we can batch this slot with the previous one
+ if (slot.is_processing()) {
+ if (!slot_batched) {
+ slot_batched = &slot;
+ } else if (!slot_batched->can_batch_with(slot)) {
+ continue;
+ }
+ }
+
// this slot still has a prompt to be processed
- if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)
- {
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto &prompt_tokens = slot.prompt_tokens;
- // we haven't tokenized the prompt yet - do it now:
- if (prompt_tokens.empty())
- {
- LOG_VERBOSE("tokenizing prompt", {{"id_slot", slot.id}, {"id_task", slot.id_task}});
-
+ // TODO: maybe move this branch outside of the loop in the future
+ if (slot.state == SLOT_STATE_STARTED) {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;
- if (slot.infill)
- {
- const bool add_bos = llama_should_add_bos_token(model);
- bool suff_rm_leading_spc = true;
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
- {
- params.input_suffix.erase(0, 1);
- suff_rm_leading_spc = false;
- }
-
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
- const int space_token = 29871; // TODO: this should not be hardcoded
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token)
- {
- suffix_tokens.erase(suffix_tokens.begin());
- }
+ slot.n_past = 0;
+ slot.n_prompt_tokens = prompt_tokens.size();
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx,
+ slot.params.n_keep, slot.n_prompt_tokens);
- auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
- if (add_bos)
- {
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+ // print prompt tokens (for debugging)
+ if (1) {
+ // first 16 tokens (avoid flooding logs)
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
+ common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
- const llama_token middle_token = llama_token_middle(model);
- if (middle_token >= 0)
- {
- embd_inp.push_back(middle_token);
+ } else {
+ // all
+ for (int i = 0; i < (int)prompt_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
+ common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
-
- prompt_tokens = embd_inp;
- }
- else
- {
- prompt_tokens =
- tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}
- slot.n_past = 0;
- slot.n_prompt_tokens = prompt_tokens.size();
-
- LOG_VERBOSE("prompt tokenized", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", slot.n_ctx},
- {"n_keep", slot.params.n_keep},
- {"n_prompt_tokens", slot.n_prompt_tokens},
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(),
- prompt_tokens.cend())},
- });
-
// empty prompt passed -> release the slot and send empty response
- if (prompt_tokens.empty())
- {
- LOG_INFO("empty prompt - releasing slot",
- {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+ if (prompt_tokens.empty()) {
+ SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
slot.release();
slot.print_timings();
send_final_response(slot);
continue;
}
- if (slot.embedding)
- {
- // this prompt is too large to process - discard it
- if (slot.n_prompt_tokens > n_ubatch)
- {
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
+ if (slot.is_non_causal()) {
+ if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size",
ERROR_TYPE_SERVER);
continue;
}
- }
- else
- {
- if (slot.params.n_keep < 0)
- {
+
+ if (slot.n_prompt_tokens > slot.n_ctx) {
+ slot.release();
+ send_error(slot, "input is larger than the max context size. skipping",
+ ERROR_TYPE_SERVER);
+ continue;
+ }
+ } else {
+ if (!params_base.ctx_shift) {
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
+ // TODO: there should be a separate parameter that control prompt truncation
+ // context shift should be applied only during the generation phase
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
+ slot.release();
+ send_error(slot,
+ "the request exceeds the available context size. try increasing the "
+ "context size or enable context shift",
+ ERROR_TYPE_INVALID_REQUEST);
+ continue;
+ }
+ }
+ if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
- // if input prompt is too big, truncate it (if group attention self-extend is disabled)
- if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
- {
+ // if input prompt is too big, truncate it
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks =
(slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
- std::vector new_tokens(prompt_tokens.begin(),
- prompt_tokens.begin() + slot.params.n_keep);
+ llama_tokens new_tokens(prompt_tokens.begin(),
+ prompt_tokens.begin() + slot.params.n_keep);
new_tokens.insert(new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep +
@@ -2315,265 +2904,183 @@ struct server_context
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
- LOG_VERBOSE("input truncated",
- {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", slot.n_ctx},
- {"n_keep", slot.params.n_keep},
- {"n_left", n_left},
- {"n_prompt_tokens", slot.n_prompt_tokens},
- {"prompt_tokens",
- tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
- });
+ SLT_WRN(slot,
+ "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n",
+ slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
- llama_sampling_reset(slot.ctx_sampling);
-
- if (!slot.params.cache_prompt)
- {
- slot.n_past_se = 0;
- slot.ga_i = 0;
- }
- else
- {
- GGML_ASSERT(slot.ga_n == 1);
-
+ if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
- slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
-
- // push the prompt into the sampling context (do not apply grammar)
- for (int i = 0; i < slot.n_past; ++i)
- {
- llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+ slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
+
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
+ if (params_base.n_cache_reuse > 0) {
+ size_t head_c = slot.n_past; // cache
+ size_t head_p = slot.n_past; // current prompt
+
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n",
+ params_base.n_cache_reuse, slot.n_past);
+
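+ // e.g. (illustrative): with n_cache_reuse = 2, cache = [a b X Y c d e] and
+ // new prompt = [a b c d e], the common prefix gives n_past = 2; the chunk [c d e]
+ // at cache positions [4, 7) is then shifted to [2, 5) instead of being recomputed,
+ // leaving n_past = 5
+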
+ while (head_c < slot.cache_tokens.size() && head_p < prompt_tokens.size()) {
+
+ size_t n_match = 0;
+ while (head_c + n_match < slot.cache_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+
+ n_match++;
+ }
+
+ if (n_match >= (size_t)params_base.n_cache_reuse) {
+ SLT_INF(slot,
+ "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> "
+ "[%zu, %zu)\n",
+ n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ // for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i],
+ // common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ // }
+
+ const int64_t kv_shift = (int64_t)head_p - (int64_t)head_c;
+
+ llama_kv_cache_seq_rm(ctx, slot.id, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
+
+ for (size_t i = 0; i < n_match; i++) {
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+ slot.n_past++;
+ }
+
+ head_c += n_match;
+ head_p += n_match;
+ } else {
+ head_c += 1;
+ }
+ }
+
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
}
}
}
- if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0)
- {
+ if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
- LOG_INFO("we have to evaluate at least 1 token to generate logits",
- {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+ SLT_WRN(slot,
+ "need to evaluate at least 1 token to generate logits, n_past = %d, "
+ "n_prompt_tokens = %d\n",
+ slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
- if (slot.ga_i > 0)
- {
- slot.n_past_se--;
- }
}
slot.n_prompt_tokens_processed = 0;
}
- if (slot.embedding)
- {
+ // non-causal tasks require the entire prompt to fit in the physical batch
+ if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
- if (batch.n_tokens + slot.n_prompt_tokens > n_batch)
- {
+ if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
}
}
- // check that we are in the right batch_type, if not defer the slot
- bool slot_type = slot.embedding ? 1 : 0;
- if (batch_type == -1)
- {
- batch_type = slot_type;
- }
- else if (batch_type != slot_type)
- {
- continue;
- }
-
// keep only the common part
- int p0 = (int)system_tokens.size() + slot.n_past;
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1))
- {
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
-
- p0 = (int)system_tokens.size();
- if (p0 != 0)
- {
- // copy over the system prompt when there is one
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
- }
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
- // there is no common part left (except for the system prompt)
+ // there is no common part left
slot.n_past = 0;
- slot.n_past_se = 0;
- slot.ga_i = 0;
- // TODO: is the system prompt ever in the sampling context?
- llama_sampling_reset(slot.ctx_sampling);
}
+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);
- LOG_INFO("kv cache rm [p0, end)", {{"id_slot", slot.id}, {"id_task", slot.id_task}, {"p0", p0}});
-
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
- int32_t ga_i = slot.ga_i;
- int32_t ga_n = slot.ga_n;
- int32_t ga_w = slot.ga_w;
-
// add prompt tokens for processing in the current batch
- // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
- for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past)
- {
- if (slot.ga_n != 1)
- {
- while (slot_npast >= ga_i + ga_w)
- {
- const int bd = (ga_w / ga_n) * (ga_n - 1);
- slot_npast -= bd;
- ga_i += ga_w / ga_n;
- }
- }
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+ // without pooling, we want to output the embeddings for all the tokens in the batch
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING &&
+ llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast,
- {slot.id + 1}, false);
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, {slot.id}, need_embd);
- if (slot.params.cache_prompt)
- {
+ if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
}
slot.n_prompt_tokens_processed++;
- slot_npast++;
+ slot.n_past++;
}
- LOG_VERBOSE("prompt processing progress",
- {
- {"id_slot", slot.id},
- {"n_past", slot.n_past},
- {"n_ctx", n_ctx},
- {"n_tokens", batch.n_tokens},
- {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
- });
-
- // entire prompt has been processed - start decoding new tokens
- if (slot.n_past == slot.n_prompt_tokens)
- {
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
+ slot.n_past, batch.n_tokens, (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+
+ // entire prompt has been processed
+ if (slot.n_past == slot.n_prompt_tokens) {
+ slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
+ common_sampler_reset(slot.smpl);
+
+ // Process all prompt tokens through sampler system
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+ }
+
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
- LOG_VERBOSE("prompt done", {
- {"id_slot", slot.id},
- {"n_past", slot.n_past},
- {"n_ctx", n_ctx},
- {"n_tokens", batch.n_tokens},
- });
+ SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
}
}
- if (batch.n_tokens >= n_batch)
- {
+ if (batch.n_tokens >= n_batch) {
break;
}
}
}
- if (batch.n_tokens == 0)
- {
- LOG_VERBOSE("no tokens to decode", {});
+ if (batch.n_tokens == 0) {
+ SRV_WRN("%s", "no tokens to decode\n");
return;
}
- LOG_VERBOSE("decoding batch", {
- {"n_tokens", batch.n_tokens},
- });
+ SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
- // make sure we're in the right embedding mode
- llama_set_embeddings(ctx, batch_type == 1);
+ if (slot_batched) {
+ // make sure we're in the right embedding mode
+ llama_set_embeddings(ctx, slot_batched->is_non_causal());
+ // apply lora, only need to do it once per batch
+ common_set_adapter_lora(ctx, slot_batched->lora);
+ }
// process the created batch of tokens
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch)
- {
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
- for (auto &slot : slots)
- {
- if (slot.ga_n != 1)
- {
- // context extension via Self-Extend
- // TODO: simplify and/or abstract this
- while (slot.n_past_se >= slot.ga_i + slot.ga_w)
- {
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
- LOG_TEE("\n");
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd,
- slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
- LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd,
- slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n,
- (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w,
- slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd,
- slot.n_past_se + ib * bd + dd);
-
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,
- slot.ga_n);
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,
- slot.n_past_se + ib * bd, dd);
-
- slot.n_past_se -= bd;
-
- slot.ga_i += slot.ga_w / slot.ga_n;
-
- LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se,
- slot.ga_i);
- }
-
- slot.n_past_se += n_tokens;
- }
- }
-
llama_batch batch_view = {
- n_tokens,
- batch.token + i,
- nullptr,
- batch.pos + i,
- batch.n_seq_id + i,
- batch.seq_id + i,
- batch.logits + i,
- 0,
- 0,
- 0, // unused
+ n_tokens, batch.token + i, nullptr, batch.pos + i,
+ batch.n_seq_id + i, batch.seq_id + i, batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
+ metrics.on_decoded(slots);
- if (ret != 0)
- {
- if (n_batch == 1 || ret < 0)
- {
+ if (ret != 0) {
+ if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
- LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size",
- {
- {"i", i},
- {"n_batch", ret},
- {"ret", ret},
- });
- for (auto &slot : slots)
- {
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i "
+ "= %d, n_batch = %d, ret = %d\n",
+ i, n_batch, ret);
+ for (auto &slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
@@ -2584,223 +3091,181 @@ struct server_context
n_batch /= 2;
i -= n_batch;
- LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try "
- "increasing it via the context size or enable defragmentation",
- {
- {"i", i},
- {"n_batch", n_batch},
- {"ret", ret},
- });
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing "
+ "it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n",
+ i, n_batch, ret);
continue; // continue loop of n_batch
}
- for (auto &slot : slots)
- {
- if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens))
- {
+ for (auto &slot : slots) {
+ if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
continue; // continue loop of slots
}
- // prompt evaluated for embedding
- if (slot.embedding)
- {
- send_embedding(slot, batch_view);
- slot.release();
- slot.i_batch = -1;
+ if (slot.state == SLOT_STATE_DONE_PROMPT) {
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
+ // prompt evaluated for embedding
+ send_embedding(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
+ send_rerank(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
+ // prompt evaluated for next-token prediction
+ slot.state = SLOT_STATE_GENERATING;
+ } else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
- completion_token_output result;
- const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+ const int tok_idx = slot.i_batch - i;
+
+ llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+
+ slot.i_batch = -1;
- llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+ common_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
- if (slot.n_decoded == 1)
- {
- slot.t_start_generation = ggml_time_us();
+
+ const int64_t t_current = ggml_time_us();
+
+ if (slot.n_decoded == 1) {
+ slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
- llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false};
- result.tok = id;
-
- const size_t n_probs = std::min(cur_p.size, (size_t)slot.sparams.n_probs);
- if (n_probs > 0)
- {
- const size_t n_valid = slot.ctx_sampling->n_valid;
+ slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
- // Make sure at least n_probs top tokens are at the front of the vector:
- if (slot.sparams.temp == 0.0f && n_probs > n_valid)
- {
- llama_sample_top_k(ctx, &cur_p, n_probs, 0);
- }
+ completion_token_output result;
+ result.tok = id;
+ result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
- if (slot.sparams.temp == 0.0f)
- {
- // With greedy sampling the probabilities have possibly not been calculated.
- for (size_t i = 0; i < n_probs; ++i)
- {
- result.probs.push_back({cur_p.data[i].id, i == 0 ? 1.0f : 0.0f});
- }
- }
- else
- {
- for (size_t i = 0; i < n_probs; ++i)
- {
- result.probs.push_back({
- cur_p.data[i].id,
- i >= n_valid
- ? 0.0f
- : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
- });
- }
- }
+ if (slot.params.sampling.n_probs > 0) {
+ populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
}
- if (!process_token(result, slot))
- {
+ if (!process_token(result, slot)) {
+ // release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
+ continue;
}
-
- slot.i_batch = -1;
}
- }
- LOG_VERBOSE("run slots completed", {});
- }
+ // do speculative decoding
+ for (auto &slot : slots) {
+ if (!slot.is_processing() || !slot.can_speculate()) {
+ continue;
+ }
- json model_meta() const
- {
- return json{
- {"vocab_type", llama_vocab_type(model)}, {"n_vocab", llama_n_vocab(model)},
- {"n_ctx_train", llama_n_ctx_train(model)}, {"n_embd", llama_n_embd(model)},
- {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)},
- };
- }
-};
+ if (slot.state != SLOT_STATE_GENERATING) {
+ continue;
+ }
-// parse the given jparams (see de.kherud.llama.args.ModelParameters#toString()) from JSON to the required C++ struct.
-static void server_params_parse(json jparams, gpt_params ¶ms)
-{
- gpt_params default_params;
-
- params.seed = json_value(jparams, "seed", default_params.seed);
- params.n_threads = json_value(jparams, "n_threads", default_params.n_threads);
- params.n_threads_draft = json_value(jparams, "n_threads_draft", default_params.n_threads_draft);
- params.n_threads_batch = json_value(jparams, "n_threads_batch", default_params.n_threads_batch);
- params.n_threads_batch_draft = json_value(jparams, "n_threads_batch_draft", default_params.n_threads_batch_draft);
- params.n_predict = json_value(jparams, "n_predict", default_params.n_predict);
- params.n_ctx = json_value(jparams, "n_ctx", default_params.n_ctx);
- params.n_batch = json_value(jparams, "n_batch", default_params.n_batch);
- params.n_ubatch = json_value(jparams, "n_ubatch", default_params.n_ubatch);
- params.n_keep = json_value(jparams, "n_keep", default_params.n_keep);
- params.n_draft = json_value(jparams, "n_draft", default_params.n_draft);
- params.n_chunks = json_value(jparams, "n_chunks", default_params.n_chunks);
- params.n_parallel = json_value(jparams, "n_parallel", default_params.n_parallel);
- params.n_sequences = json_value(jparams, "n_sequences", default_params.n_sequences);
- params.p_split = json_value(jparams, "p_split", default_params.p_split);
- params.grp_attn_n = json_value(jparams, "grp_attn_n", default_params.grp_attn_n);
- params.grp_attn_w = json_value(jparams, "grp_attn_w", default_params.grp_attn_w);
- params.n_print = json_value(jparams, "n_print", default_params.n_print);
- params.rope_freq_base = json_value(jparams, "rope_freq_base", default_params.rope_freq_base);
- params.rope_freq_scale = json_value(jparams, "rope_freq_scale", default_params.rope_freq_scale);
- params.yarn_ext_factor = json_value(jparams, "yarn_ext_factor", default_params.yarn_ext_factor);
- params.yarn_attn_factor = json_value(jparams, "yarn_attn_factor", default_params.yarn_attn_factor);
- params.yarn_beta_fast = json_value(jparams, "yarn_beta_fast", default_params.yarn_beta_fast);
- params.yarn_beta_slow = json_value(jparams, "yarn_beta_slow", default_params.yarn_beta_slow);
- params.yarn_orig_ctx = json_value(jparams, "yarn_orig_ctx", default_params.yarn_orig_ctx);
- params.defrag_thold = json_value(jparams, "defrag_thold", default_params.defrag_thold);
- params.numa = json_value(jparams, "numa", default_params.numa);
- params.rope_scaling_type = json_value(jparams, "rope_scaling_type", default_params.rope_scaling_type);
- params.pooling_type = json_value(jparams, "pooling_type", default_params.pooling_type);
- params.model = json_value(jparams, "model", default_params.model);
- params.model_draft = json_value(jparams, "model_draft", default_params.model_draft);
- params.model_alias = json_value(jparams, "model_alias", default_params.model_alias);
- params.model_url = json_value(jparams, "model_url", default_params.model_url);
- params.hf_repo = json_value(jparams, "hf_repo", default_params.hf_repo);
- params.hf_file = json_value(jparams, "hf_file", default_params.hf_file);
- params.prompt = json_value(jparams, "prompt", default_params.prompt);
- params.prompt_file = json_value(jparams, "prompt_file", default_params.prompt_file);
- params.path_prompt_cache = json_value(jparams, "path_prompt_cache", default_params.path_prompt_cache);
- params.input_prefix = json_value(jparams, "input_prefix", default_params.input_prefix);
- params.input_suffix = json_value(jparams, "input_suffix", default_params.input_suffix);
- params.antiprompt = json_value(jparams, "antiprompt", default_params.antiprompt);
- params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static);
- params.lookup_cache_dynamic = json_value(jparams, "lookup_cache_dynamic", default_params.lookup_cache_dynamic);
- params.logits_file = json_value(jparams, "logits_file", default_params.logits_file);
- params.lora_adapter = json_value(jparams, "lora_adapter", default_params.lora_adapter);
- params.embedding = json_value(jparams, "embedding", default_params.embedding);
- params.escape = json_value(jparams, "escape", default_params.escape);
- params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
- params.flash_attn = json_value(jparams, "flash_attn", default_params.flash_attn);
- params.input_prefix_bos = json_value(jparams, "input_prefix_bos", default_params.input_prefix_bos);
- params.ignore_eos = json_value(jparams, "ignore_eos", default_params.ignore_eos);
- params.use_mmap = json_value(jparams, "use_mmap", default_params.use_mmap);
- params.use_mlock = json_value(jparams, "use_mlock", default_params.use_mlock);
- params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload);
- params.system_prompt = json_value(jparams, "system_prompt", default_params.system_prompt);
- params.chat_template = json_value(jparams, "chat_template", default_params.chat_template);
-
- if (jparams.contains("n_gpu_layers"))
- {
- if (llama_supports_gpu_offload())
- {
- params.n_gpu_layers = json_value(jparams, "n_gpu_layers", default_params.n_gpu_layers);
- params.n_gpu_layers_draft = json_value(jparams, "n_gpu_layers_draft", default_params.n_gpu_layers_draft);
- }
- else
- {
- LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
- "See main README.md for information on enabling GPU BLAS support",
- {{"n_gpu_layers", params.n_gpu_layers}});
- }
- }
+ // determine the max draft that fits the current slot state
+ int n_draft_max = slot.params.speculative.n_max;
- if (jparams.contains("split_mode"))
- {
- params.split_mode = json_value(jparams, "split_mode", default_params.split_mode);
-// todo: the definition checks here currently don't work due to cmake visibility reasons
-#ifndef GGML_USE_CUDA
- fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
-#endif
- }
+ // note: n_past is not yet increased for the `id` token sampled above
+ // also, need to leave space for 1 extra token to allow context shifts
+ n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
- if (jparams.contains("tensor_split"))
- {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
- std::vector<float> tensor_split = jparams["tensor_split"].get<std::vector<float>>();
- GGML_ASSERT(tensor_split.size() <= llama_max_devices());
+ if (slot.n_remaining > 0) {
+ n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+ }
- for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device)
- {
- if (i_device < tensor_split.size())
- {
- params.tensor_split[i_device] = tensor_split.at(i_device);
- }
- else
- {
- params.tensor_split[i_device] = 0.0f;
+ SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
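+ // e.g. (illustrative): with n_max = 16, n_ctx = 4096, n_past = 4000 and n_remaining = 10,
+ // n_draft_max = min(16, 4096 - 4000 - 2, 10 - 1) = 9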
+
+ if (n_draft_max < slot.params.speculative.n_min) {
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n",
+ n_draft_max, slot.params.speculative.n_min);
+
+ continue;
+ }
+
+ llama_token id = slot.sampled;
+
+ struct common_speculative_params params_spec;
+ params_spec.n_draft = n_draft_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+ params_spec.p_min = slot.params.speculative.p_min;
+
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+ // ignore small drafts
+ if (slot.params.speculative.n_min > (int)draft.size()) {
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
+
+ continue;
+ }
+
+ // construct the speculation batch
+ common_batch_clear(slot.batch_spec);
+ common_batch_add(slot.batch_spec, id, slot.n_past, {slot.id}, true);
+
+ for (size_t i = 0; i < draft.size(); ++i) {
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, {slot.id}, true);
+ }
+
+ SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+ llama_decode(ctx, slot.batch_spec);
+
+ // the accepted tokens from the speculation
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+ slot.n_past += ids.size();
+ slot.n_decoded += ids.size();
+
+ slot.cache_tokens.push_back(id);
+ slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+ llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+ for (size_t i = 0; i < ids.size(); ++i) {
+ completion_token_output result;
+
+ result.tok = ids[i];
+ result.text_to_send =
+ common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+ result.prob = 1.0f; // set later
+
+ // TODO: set result.probs
+
+ if (!process_token(result, slot)) {
+ // release slot because of stop condition
+ slot.release();
+ slot.print_timings();
+ send_final_response(slot);
+ metrics.on_prediction(slot);
+ break;
+ }
+ }
+
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int)ids.size() - 1, (int)draft.size(),
+ slot.n_past);
}
}
-#else
- LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {});
-#endif // GGML_USE_CUDA
- }
- if (jparams.contains("main_gpu"))
- {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
- params.main_gpu = json_value(jparams, "main_gpu", default_params.main_gpu);
-#else
- LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {});
-#endif
+ SRV_DBG("%s", "run slots completed\n");
}
- gpt_params_handle_model_default(params);
-}
+ json model_meta() const {
+ return json{
+ {"vocab_type", llama_vocab_type(vocab)}, {"n_vocab", llama_vocab_n_tokens(vocab)},
+ {"n_ctx_train", llama_model_n_ctx_train(model)}, {"n_embd", llama_model_n_embd(model)},
+ {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)},
+ };
+ }
+};
diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp
index 7de7eac4..603424b4 100644
--- a/src/main/cpp/utils.hpp
+++ b/src/main/cpp/utils.hpp
@@ -1,203 +1,360 @@
#pragma once
+#include "base64.hpp"
#include "common.h"
#include "llama.h"
+#include "log.h"
-#include "json.hpp"
+#ifndef NDEBUG
+// crash the server in debug mode, otherwise send an http 500 error
+#define CPPHTTPLIB_NO_EXCEPTIONS 1
+#endif
+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// #include "httplib.h"
+
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "nlohmann/json.hpp"
+
+#include "chat.h"
+#include
#include
#include
#include
#include
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
using json = nlohmann::ordered_json;
-// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
-enum error_type
-{
- ERROR_TYPE_INVALID_REQUEST,
- ERROR_TYPE_AUTHENTICATION,
- ERROR_TYPE_SERVER,
- ERROR_TYPE_NOT_FOUND,
- ERROR_TYPE_PERMISSION,
- ERROR_TYPE_UNAVAILABLE, // custom error
- ERROR_TYPE_NOT_SUPPORTED, // custom error
-};
-
-extern bool log_json;
-extern std::function<void(ggml_log_level, const char *, void *)> log_callback;
-
-#if SERVER_VERBOSE
-#define LOG_VERBOSE(MSG, ...) \
- do \
- { \
- server_log(GGML_LOG_LEVEL_DEBUG, __func__, __LINE__, MSG, __VA_ARGS__); \
- } while (0)
-#else
-#define LOG_VERBOSE(MSG, ...)
-#endif
-
-#define LOG_ERROR(MSG, ...) server_log(GGML_LOG_LEVEL_ERROR, __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log(GGML_LOG_LEVEL_WARN, __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO(MSG, ...) server_log(GGML_LOG_LEVEL_INFO, __func__, __LINE__, MSG, __VA_ARGS__)
-
-static inline void server_log(ggml_log_level level, const char *function, int line, const char *message,
- const json &extra);
-
-template <typename T> static T json_value(const json &body, const std::string &key, const T &default_value)
-{
+#define SLT_INF(slot, fmt, ...) \
+ LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) \
+ LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) \
+ LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) \
+ LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
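+
+// usage sketch (illustrative; assumes a `slot` object is in scope): the slot variants
+// prepend the calling function plus the slot/task ids, the srv/que variants only the function:
+//   SRV_INF("%s", "server is listening\n");
+//   SLT_DBG(slot, "n_past = %d\n", slot.n_past);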
+
+template <typename T> static T json_value(const json &body, const std::string &key, const T &default_value) {
// Fallback null to default value
- if (body.contains(key) && !body.at(key).is_null())
- {
- try
- {
+ if (body.contains(key) && !body.at(key).is_null()) {
+ try {
return body.at(key);
- }
- catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &)
- {
- std::stringstream ss;
- ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name()
- << "', using default value.";
- LOG_WARNING(ss.str().c_str(), body);
+ } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
+ LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(),
+ json(default_value).type_name());
return default_value;
}
- }
- else
- {
+ } else {
return default_value;
}
}
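+// usage sketch (illustrative): with body = {"n_predict": 128},
+//   json_value(body, "n_predict", 64) -> 128
+//   json_value(body, "top_k", 40)     -> 40   (missing key falls back to the default)
+//   a present key of the wrong type also falls back to the default and logs a warning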
-static const char *log_level_to_string(ggml_log_level level)
-{
- switch (level)
- {
- case GGML_LOG_LEVEL_ERROR:
- return "ERROR";
- case GGML_LOG_LEVEL_WARN:
- return "WARN";
- default:
- case GGML_LOG_LEVEL_INFO:
- return "INFO";
- case GGML_LOG_LEVEL_DEBUG:
- return "DEBUG";
+const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
+
+//
+// tokenizer and input processing utils
+//
+
+static bool json_is_array_of_numbers(const json &data) {
+ if (data.is_array()) {
+ for (const auto &e : data) {
+ if (!e.is_number_integer()) {
+ return false;
+ }
+ }
+ return true;
}
+ return false;
}
-static inline void server_log(ggml_log_level level, const char *function, int line, const char *message,
- const json &extra)
-{
- std::stringstream ss_tid;
- ss_tid << std::this_thread::get_id();
-
- if (log_json)
- {
- json log = json{
- {"msg", message},
-#if SERVER_VERBOSE
- {"ts", time(nullptr)}, {"level", log_level_to_string(level)}, {"tid", ss_tid.str()}, {"function", function},
- {"line", line},
-#endif
- };
-
- if (!extra.empty())
- {
- log.merge_patch(extra);
+// does the array contain BOTH numbers and strings?
+static bool json_is_array_of_mixed_numbers_strings(const json &data) {
+ bool seen_string = false;
+ bool seen_number = false;
+ if (data.is_array()) {
+ for (const auto &e : data) {
+ seen_string |= e.is_string();
+ seen_number |= e.is_number_integer();
+ if (seen_number && seen_string) {
+ return true;
+ }
}
+ }
+ return false;
+}
- auto dump = log.dump(-1, ' ', false, json::error_handler_t::replace);
- if (log_callback == nullptr)
- {
- printf("%s\n", dump.c_str());
+// get value by path (key1 / key2)
+static json json_get_nested_values(const std::vector<std::string> &paths, const json &js) {
+ json result = json::object();
+
+ for (const std::string &path : paths) {
+ json current = js;
+ const auto keys = string_split(path, /*separator*/ '/');
+ bool valid_path = true;
+ for (const std::string &k : keys) {
+ if (valid_path && current.is_object() && current.contains(k)) {
+ current = current[k];
+ } else {
+ valid_path = false;
+ }
}
- else
- {
- log_callback(level, dump.c_str(), nullptr);
+ if (valid_path) {
+ result[path] = current;
}
}
- else
- {
- std::stringstream ss;
- ss << message;
-
- if (!extra.empty())
- {
- for (const auto &el : extra.items())
- {
- const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
- ss << " " << el.key() << "=" << value;
+ return result;
+}
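+// usage sketch (illustrative): with js = {"model": {"name": "llama"}, "n_ctx": 4096},
+//   json_get_nested_values({"model/name", "n_ctx", "missing/key"}, js)
+//   returns {"model/name": "llama", "n_ctx": 4096} - invalid paths are skipped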
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_prompt, bool add_special,
+ bool parse_special) {
+ // If `add_special` is true, BOS is only added when json_prompt is a string
+ // or when the first element of the json_prompt array is a string.
+ llama_tokens prompt_tokens;
+
+ if (json_prompt.is_array()) {
+ bool first = true;
+ for (const auto &p : json_prompt) {
+ if (p.is_string()) {
+ auto s = p.template get<std::string>();
+
+ llama_tokens p;
+ if (first) {
+ p = common_tokenize(vocab, s, add_special, parse_special);
+ first = false;
+ } else {
+ p = common_tokenize(vocab, s, false, parse_special);
+ }
+
+ prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+ } else {
+ if (first) {
+ first = false;
+ }
+
+ prompt_tokens.push_back(p.template get<llama_token>());
}
}
+ } else {
+ auto s = json_prompt.template get<std::string>();
+ prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
+ }
-#if SERVER_VERBOSE
- ss << " | ts " << time(nullptr) << " | tid " << ss_tid.str() << " | " << function << " line " << line;
-#endif
+ return prompt_tokens;
+}
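+// usage sketch (illustrative; assumes `vocab` is in scope):
+//   tokenize_mixed(vocab, json::parse(R"([12, 34, "hello", 56])"), true, true)
+//   keeps 12, 34 and 56 as raw token ids and tokenizes "hello" in between;
+//   BOS is only prepended when the first element is a string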
- const std::string str = ss.str();
- if (log_callback == nullptr)
- {
- printf("[%4s] %.*s\n", log_level_to_string(level), (int)str.size(), str.data());
+/**
+ * break the input "prompt" object into multiple prompt if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
+ */
+static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab, const json &json_prompt,
+ bool add_special, bool parse_special) {
+ std::vector<llama_tokens> result;
+ if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
+ // string or mixed
+ result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
+ } else if (json_is_array_of_numbers(json_prompt)) {
+ // array of tokens
+ result.push_back(json_prompt.get<llama_tokens>());
+ } else if (json_prompt.is_array()) {
+ // array of prompts
+ result.reserve(json_prompt.size());
+ for (const auto &p : json_prompt) {
+ if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
+ result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
+ } else if (json_is_array_of_numbers(p)) {
+ // array of tokens
+ result.push_back(p.get<llama_tokens>());
+ } else {
+ throw std::runtime_error(
+ "element of \"prompt\" must be a string, a list of tokens, or a list of mixed strings & tokens");
+ }
}
- else
- {
- log_callback(level, str.c_str(), nullptr);
+ } else {
+ throw std::runtime_error(
+ "\"prompt\" must be a string, a list of tokens, a list of mixed strings & tokens, or a list of prompts");
+ }
+ if (result.empty()) {
+ throw std::runtime_error("\"prompt\" must not be empty");
+ }
+ return result;
+}
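+// usage sketch (illustrative; assumes `vocab` and a request `body` json are in scope):
+//   std::vector<llama_tokens> prompts =
+//       tokenize_input_prompts(vocab, body.at("prompt"), /*add_special=*/true, /*parse_special=*/true);
+//   "abc"            -> 1 prompt,  ["abc", [1, 2, 3]] -> 2 prompts,
+//   [1, 2, "abc", 3] -> 1 mixed prompt, anything else throws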
+
+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string &text) {
+ size_t len = text.size();
+ if (len == 0)
+ return 0;
+
+ // Check the last few bytes to see if a multi-byte character is cut off
+ for (size_t i = 1; i <= 4 && i <= len; ++i) {
+ unsigned char c = text[len - i];
+ // Check for start of a multi-byte sequence from the end
+ if ((c & 0xE0) == 0xC0) {
+ // 2-byte character start: 110xxxxx
+ // Needs at least 2 bytes
+ if (i < 2)
+ return len - i;
+ } else if ((c & 0xF0) == 0xE0) {
+ // 3-byte character start: 1110xxxx
+ // Needs at least 3 bytes
+ if (i < 3)
+ return len - i;
+ } else if ((c & 0xF8) == 0xF0) {
+ // 4-byte character start: 11110xxx
+ // Needs at least 4 bytes
+ if (i < 4)
+ return len - i;
}
}
- fflush(stdout);
+
+ // If no cut-off multi-byte character is found, return full length
+ return len;
}
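+// example (illustrative): for a string ending in the first byte of a 2-byte UTF-8
+// sequence (e.g. "abc" followed by 0xC3), validate_utf8 returns 3, so the caller can
+// stream "abc" now and emit the remaining bytes once the sequence is complete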
//
-// chat template utils
+// template utils
//
-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model *model, const std::string &tmpl,
- const std::vector<json> &messages)
-{
- std::vector chat;
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
+static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_tokens &query, const llama_tokens &doc) {
+ llama_tokens result;
- for (size_t i = 0; i < messages.size(); ++i)
- {
- const auto &curr_msg = messages[i];
+ result.reserve(doc.size() + query.size() + 4);
+ result.push_back(llama_vocab_bos(vocab));
+ result.insert(result.end(), query.begin(), query.end());
+ result.push_back(llama_vocab_eos(vocab));
+ result.push_back(llama_vocab_sep(vocab));
+ result.insert(result.end(), doc.begin(), doc.end());
+ result.push_back(llama_vocab_eos(vocab));
- std::string role = json_value(curr_msg, "role", std::string(""));
+ return result;
+}
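+// usage sketch (illustrative; assumes `vocab` plus `query_text` / `doc_text` strings in scope):
+//   llama_tokens q   = common_tokenize(vocab, query_text, false, true);
+//   llama_tokens d   = common_tokenize(vocab, doc_text,   false, true);
+//   llama_tokens inp = format_rerank(vocab, q, d); // [BOS]q[EOS][SEP]d[EOS]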
- std::string content;
- if (curr_msg.contains("content"))
- {
- if (curr_msg["content"].is_string())
- {
- content = curr_msg["content"].get<std::string>();
- }
- else if (curr_msg["content"].is_array())
- {
- for (const auto &part : curr_msg["content"])
- {
- if (part.contains("text"))
- {
- content += "\n" + part["text"].get<std::string>();
- }
- }
- }
- else
- {
- throw std::runtime_error(
- "Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
- }
- }
- else
- {
- throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+// format infill task
+static llama_tokens format_infill(const llama_vocab *vocab, const json &input_prefix, const json &input_suffix,
+ const json &input_extra, const int n_batch, const int n_predict, const int n_ctx,
+ const bool spm_infill, const llama_tokens &tokens_prompt) {
+ // TODO: optimize this block by reducing memory allocations and movement
+
+ // use FIM repo-level pattern:
+ // ref: https://arxiv.org/pdf/2409.12186
+ //
+ // [FIM_REP]myproject
+ // [FIM_SEP]filename0
+ // extra chunk 0
+ // [FIM_SEP]filename1
+ // extra chunk 1
+ // ...
+ // [FIM_SEP]filename
+ // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
+ //
+ llama_tokens extra_tokens;
+ extra_tokens.reserve(n_ctx);
+
+ auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
+ auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
+
+ if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
+ // TODO: make project name an input
+ static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
+
+ extra_tokens.push_back(llama_vocab_fim_rep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+ }
+ for (const auto &chunk : input_extra) {
+ // { "text": string, "filename": string }
+ const std::string text = json_value(chunk, "text", std::string());
+ const std::string filename = json_value(chunk, "filename", std::string("tmp"));
+
+ if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
+ const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
+
+ extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+ } else {
+ // chunk separator in binary form to avoid confusing the AI
+ static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70,
+ 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+ static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
+
+ extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
- chat.push_back({role, content});
+ const auto chunk_tokens = common_tokenize(vocab, text, false, false);
+ extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+ }
+
+ if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
+ // TODO: current filename
+ static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
+
+ extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
+ extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
- auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
- LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
- return formatted_chat;
+ // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+ const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3 * (n_batch / 4));
+ const int n_suffix_take =
+ std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch / 4) - (2 + tokens_prompt.size())));
+
+ SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take,
+ (n_prefix_take + n_suffix_take));
+
+ // fill the rest of the context with extra chunks
+    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2 * n_predict), extra_tokens.size());
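+    // illustrative numbers: with n_ctx = 8192, n_batch = 2048 and n_predict = 1024,
+    // up to 4096 extra-context tokens can be used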
+
+ tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+ tokens_suffix.resize(n_suffix_take);
+
+ tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
+ tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+ tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
+
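+    // spm_infill selects the suffix-prefix-middle token order expected by some FIM models;
+    // the default order is prefix-suffix-middle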
+ auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
+ auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
+
+ if (llama_vocab_get_add_bos(vocab)) {
+ embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
+ }
+
+ SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int)extra_tokens.size());
+
+ // put the extra context before the FIM prefix
+ embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
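+    // note: only the last n_extra_take tokens of extra_tokens are kept, so the earliest
+    // chunks in input_extra are dropped first when the context is too small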
+
+ embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+ embd_inp.push_back(llama_vocab_fim_mid(vocab));
+
+ return embd_inp;
}
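+// illustrative call shape (hypothetical values; in practice this is invoked from the
+// server's infill request handling):
+//   auto prompt_tokens = format_infill(vocab, json("int add(int a, int b) {"), json("}"),
+//                                      json::array(), /*n_batch*/ 2048, /*n_predict*/ 128,
+//                                      /*n_ctx*/ 8192, /*spm_infill*/ false, llama_tokens{});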
//
@@ -208,13 +365,9 @@ static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
-static inline bool is_base64(uint8_t c)
-{
- return (isalnum(c) || (c == '+') || (c == '/'));
-}
+static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); }
-static inline std::vector<uint8_t> base64_decode(const std::string &encoded_string)
-{
+static inline std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
int i = 0;
int j = 0;
int in_ = 0;
@@ -226,14 +379,11 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
    std::vector<uint8_t> ret;
- while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
- {
+ while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_];
in_++;
- if (i == 4)
- {
- for (i = 0; i < 4; i++)
- {
+ if (i == 4) {
+ for (i = 0; i < 4; i++) {
char_array_4[i] = base64_chars.find(char_array_4[i]);
}
@@ -241,8 +391,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
- for (i = 0; (i < 3); i++)
- {
+ for (i = 0; (i < 3); i++) {
ret.push_back(char_array_3[i]);
}
@@ -250,15 +399,12 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
}
}
- if (i)
- {
- for (j = i; j < 4; j++)
- {
+ if (i) {
+ for (j = i; j < 4; j++) {
char_array_4[j] = 0;
}
- for (j = 0; j < 4; j++)
- {
+ for (j = 0; j < 4; j++) {
char_array_4[j] = base64_chars.find(char_array_4[j]);
}
@@ -266,8 +412,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
- for (j = 0; j < i - 1; j++)
- {
+ for (j = 0; j < i - 1; j++) {
ret.push_back(char_array_3[j]);
}
}
@@ -279,8 +424,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
// random string / id
//
-static std::string random_string()
-{
+static std::string random_string() {
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
std::random_device rd;
@@ -288,63 +432,30 @@ static std::string random_string()
std::string result(32, ' ');
- for (int i = 0; i < 32; ++i)
- {
+ for (int i = 0; i < 32; ++i) {
result[i] = str[generator() % str.size()];
}
return result;
}
-static std::string gen_chatcmplid()
-{
- std::stringstream chatcmplid;
- chatcmplid << "chatcmpl-" << random_string();
-
- return chatcmplid.str();
-}
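+// generated ids look like "chatcmpl-" followed by 32 random alphanumeric characters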
+static std::string gen_chatcmplid() { return "chatcmpl-" + random_string(); }
//
// other common utils
//
-static size_t common_part(const std::vector<llama_token> &a, const std::vector