@@ -399,7 +399,7 @@ class llama_token_data_array(ctypes.Structure):
399399# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
400400# // - pos : the positions of the respective token in the sequence
401401# // - seq_id : the sequence to which the respective token belongs
402- # // - logits : if zero, the logits for the respective token will not be output
402+ # // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
403403# //
404404# typedef struct llama_batch {
405405# int32_t n_tokens;
@@ -409,7 +409,7 @@ class llama_token_data_array(ctypes.Structure):
409409# llama_pos * pos;
410410# int32_t * n_seq_id;
411411# llama_seq_id ** seq_id;
412- # int8_t * logits;
412+ # int8_t * logits; // TODO: rename this to "output"
413413
414414
415415# // NOTE: helpers for smooth API transition - can be deprecated in the future
@@ -572,7 +572,7 @@ class llama_model_params(ctypes.Structure):
572572
573573# // Keep the booleans together to avoid misalignment during copy-by-value.
574574# bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
575- # bool embedding;  // embedding mode only
575+ # bool embeddings; // if true, extract embeddings (together with logits)
576576# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
577577
578578# // Abort callback
@@ -605,7 +605,7 @@ class llama_context_params(ctypes.Structure):
605605 type_k (int): data type for K cache
606606 type_v (int): data type for V cache
607607 logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
608- embedding (bool): embedding mode only
608+ embeddings (bool): if true, extract embeddings (together with logits)
609609 offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
610610 abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
611611 abort_callback_data (ctypes.c_void_p): data for abort_callback
@@ -632,7 +632,7 @@ class llama_context_params(ctypes.Structure):
632632 ("type_k" , ctypes .c_int ),
633633 ("type_v" , ctypes .c_int ),
634634 ("logits_all" , ctypes .c_bool ),
635- ("embedding " , ctypes .c_bool ),
635+ ("embeddings " , ctypes .c_bool ),
636636 ("offload_kqv" , ctypes .c_bool ),
637637 ("abort_callback" , ggml_abort_callback ),
638638 ("abort_callback_data" , ctypes .c_void_p ),
@@ -1774,8 +1774,8 @@ def llama_get_logits_ith(
17741774 ...
17751775
17761776
1777- # Get the embeddings for the input
1778- # shape: [n_embd] (1-dimensional)
1777+ # // Get all output token embeddings
1778+ # // shape: [n_tokens*n_embd] (1-dimensional)
17791779# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
17801780@ctypes_function (
17811781 "llama_get_embeddings" , [llama_context_p_ctypes ], ctypes .POINTER (ctypes .c_float )
@@ -1786,8 +1786,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
17861786 ...
17871787
17881788
1789- # // Get the embeddings for the ith sequence
1789+ # // Get the embeddings for the ith token
17901790# // llama_get_embeddings(ctx) + i*n_embd
1791+ # // shape: [n_embd] (1-dimensional)
17911792# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
17921793@ctypes_function (
17931794 "llama_get_embeddings_ith" ,
@@ -1802,6 +1803,23 @@ def llama_get_embeddings_ith(
18021803 ...
18031804
18041805
1806+ # // Get the embeddings for a sequence id
1807+ # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1808+ # // shape: [n_embd] (1-dimensional)
1809+ # LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
1810+ @ctypes_function (
1811+ "llama_get_embeddings_seq" ,
1812+ [llama_context_p_ctypes , llama_seq_id ],
1813+ ctypes .POINTER (ctypes .c_float ),
1814+ )
1815+ def llama_get_embeddings_seq (
1816+ ctx : llama_context_p , seq_id : Union [llama_seq_id , int ], /
1817+ ) -> CtypesArray [ctypes .c_float ]:
1818+ """Get the embeddings for a sequence id
1819+ Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1820+ shape: [n_embd] (1-dimensional)"""
1821+ ...
1822+
18051823# //
18061824# // Vocab
18071825# //
0 commit comments