Commit 93dc56a: Update llama.cpp

1 parent: 87a6e57

File tree: 3 files changed (+30 -12 lines)

llama_cpp/llama.py (+3 -3)

@@ -293,7 +293,7 @@ def __init__(
         self.context_params.logits_all = (
             logits_all if draft_model is None else True
         )  # Must be set to True for speculative decoding
-        self.context_params.embedding = embedding
+        self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
 
         # Sampling Params
@@ -787,7 +787,7 @@ def embed(
         n_embd = self.n_embd()
         n_batch = self.n_batch
 
-        if self.context_params.embedding == False:
+        if self.context_params.embeddings == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
             )
@@ -1725,7 +1725,7 @@ def __getstate__(self):
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
-            embedding=self.context_params.embedding,
+            embedding=self.context_params.embeddings,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
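
Note on the high-level API: the `Llama(..., embedding=True)` constructor argument keeps its old name; only the underlying `llama_context_params` field is now called `embeddings`, and `Llama.embed()` checks that field before running. A minimal usage sketch (the model path is a placeholder, not part of this commit):

from llama_cpp import Llama

# embedding=True now sets context_params.embeddings under the hood.
llm = Llama(model_path="./models/example.gguf", embedding=True)  # placeholder path

vec = llm.embed("Hello, world!")  # one embedding vector of length n_embd
print(len(vec))

# Without embedding=True, embed() raises:
#   RuntimeError: Llama model must be created with embedding=True to call this method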

llama_cpp/llama_cpp.py (+26 -8)

@@ -399,7 +399,7 @@ class llama_token_data_array(ctypes.Structure):
 # // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
 # // - pos    : the positions of the respective token in the sequence
 # // - seq_id : the sequence to which the respective token belongs
-# // - logits : if zero, the logits for the respective token will not be output
+# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
 # //
 # typedef struct llama_batch {
 #     int32_t n_tokens;
@@ -409,7 +409,7 @@ class llama_token_data_array(ctypes.Structure):
 #     llama_pos    *  pos;
 #     int32_t      *  n_seq_id;
 #     llama_seq_id ** seq_id;
-#     int8_t       *  logits;
+#     int8_t       *  logits; // TODO: rename this to "output"
 
 
 # // NOTE: helpers for smooth API transition - can be deprecated in the future
@@ -572,7 +572,7 @@ class llama_model_params(ctypes.Structure):
 
 # // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-#     bool embedding;   // embedding mode only
+#     bool embeddings;  // if true, extract embeddings (together with logits)
 #     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 
 # // Abort callback
@@ -605,7 +605,7 @@ class llama_context_params(ctypes.Structure):
         type_k (int): data type for K cache
         type_v (int): data type for V cache
         logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        embedding (bool): embedding mode only
+        embeddings (bool): if true, extract embeddings (together with logits)
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
         abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
         abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
@@ -632,7 +632,7 @@ class llama_context_params(ctypes.Structure):
         ("type_k", ctypes.c_int),
         ("type_v", ctypes.c_int),
         ("logits_all", ctypes.c_bool),
-        ("embedding", ctypes.c_bool),
+        ("embeddings", ctypes.c_bool),
         ("offload_kqv", ctypes.c_bool),
         ("abort_callback", ggml_abort_callback),
         ("abort_callback_data", ctypes.c_void_p),
@@ -1774,8 +1774,8 @@ def llama_get_logits_ith(
     ...
 
 
-# Get the embeddings for the input
-# shape: [n_embd] (1-dimensional)
+# // Get all output token embeddings
+# // shape: [n_tokens*n_embd] (1-dimensional)
 # LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 @ctypes_function(
     "llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
@@ -1786,8 +1786,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
     ...
 
 
-# // Get the embeddings for the ith sequence
+# // Get the embeddings for the ith token
 # // llama_get_embeddings(ctx) + i*n_embd
+# // shape: [n_embd] (1-dimensional)
 # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 @ctypes_function(
     "llama_get_embeddings_ith",
@@ -1802,6 +1803,23 @@ def llama_get_embeddings_ith(
     ...
 
 
+# // Get the embeddings for a sequence id
+# // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+# // shape: [n_embd] (1-dimensional)
+# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+@ctypes_function(
+    "llama_get_embeddings_seq",
+    [llama_context_p_ctypes, llama_seq_id],
+    ctypes.POINTER(ctypes.c_float),
+)
+def llama_get_embeddings_seq(
+    ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /
+) -> CtypesArray[ctypes.c_float]:
+    """Get the embeddings for a sequence id
+    Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    shape: [n_embd] (1-dimensional)"""
+    ...
+
 # //
 # // Vocab
 # //
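
For context, a sketch of how the newly bound llama_get_embeddings_seq could be used from the low-level API. It assumes a `ctx`/`model` pair created elsewhere with `embeddings=True` and a pooling type other than LLAMA_POOLING_TYPE_NONE, with a batch for the given `seq_id` already decoded via llama_decode(); the helper name `sequence_embedding` is ours, not part of the bindings:

import llama_cpp

def sequence_embedding(ctx, model, seq_id: int = 0):
    # Read the pooled embedding for one sequence after llama_decode().
    n_embd = llama_cpp.llama_n_embd(model)
    ptr = llama_cpp.llama_get_embeddings_seq(ctx, seq_id)
    if not ptr:  # NULL when pooling_type is LLAMA_POOLING_TYPE_NONE
        raise RuntimeError("no pooled embedding available for this sequence")
    return [ptr[i] for i in range(n_embd)]  # copy out n_embd floats

For per-token rather than pooled output, llama_get_embeddings_ith(ctx, i) returns the same kind of float pointer, offset to token i (llama_get_embeddings(ctx) + i*n_embd), as the updated comments above describe.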

vendor/llama.cpp (+1 -1): submodule pointer updated
