@@ -399,7 +399,7 @@ class llama_token_data_array(ctypes.Structure):
# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
# // - pos : the positions of the respective token in the sequence
# // - seq_id : the sequence to which the respective token belongs
- # // - logits : if zero, the logits for the respective token will not be output
+ # // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
# //
# typedef struct llama_batch {
# int32_t n_tokens;
@@ -409,7 +409,7 @@ class llama_token_data_array(ctypes.Structure):
# llama_pos * pos;
# int32_t * n_seq_id;
# llama_seq_id ** seq_id;
- # int8_t * logits;
+ # int8_t * logits; // TODO: rename this to "output"


# // NOTE: helpers for smooth API transition - can be deprecated in the future
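The per-token `logits` flags now gate embedding output as well as logits. Below is a minimal sketch (not part of this patch) of filling a batch through these bindings, assuming the `llama_batch_init`/`llama_batch_free` helpers exposed elsewhere in this module; output is requested only for the final token:

import llama_cpp

def fill_batch(tokens: list[int], seq_id: int = 0) -> llama_cpp.llama_batch:
    # embd=0 means the batch carries token ids rather than embeddings; one seq id per token
    batch = llama_cpp.llama_batch_init(len(tokens), 0, 1)
    batch.n_tokens = len(tokens)
    for i, tok in enumerate(tokens):
        batch.token[i] = tok
        batch.pos[i] = i
        batch.n_seq_id[i] = 1
        batch.seq_id[i][0] = seq_id
        batch.logits[i] = 0            # zero: no logits/embeddings output for this token
    batch.logits[len(tokens) - 1] = 1  # request output for the last token only
    return batch

The returned batch would later be released with `llama_batch_free`.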
@@ -572,7 +572,7 @@ class llama_model_params(ctypes.Structure):

# // Keep the booleans together to avoid misalignment during copy-by-value.
# bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
- # bool embedding; // embedding mode only
+ # bool embeddings; // if true, extract embeddings (together with logits)
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU

# // Abort callback
@@ -605,7 +605,7 @@ class llama_context_params(ctypes.Structure):
type_k (int): data type for K cache
type_v (int): data type for V cache
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
- embedding (bool): embedding mode only
+ embeddings (bool): if true, extract embeddings (together with logits)
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
@@ -632,7 +632,7 @@ class llama_context_params(ctypes.Structure):
("type_k", ctypes.c_int),
("type_v", ctypes.c_int),
("logits_all", ctypes.c_bool),
- ("embedding", ctypes.c_bool),
+ ("embeddings", ctypes.c_bool),
("offload_kqv", ctypes.c_bool),
("abort_callback", ggml_abort_callback),
("abort_callback_data", ctypes.c_void_p),
@@ -1774,8 +1774,8 @@ def llama_get_logits_ith(
...


- # Get the embeddings for the input
- # shape: [n_embd] (1-dimensional)
+ # // Get all output token embeddings
+ # // shape: [n_tokens*n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@ctypes_function(
"llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
@@ -1786,8 +1786,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
...


- # // Get the embeddings for the ith sequence
+ # // Get the embeddings for the ith token
# // llama_get_embeddings(ctx) + i*n_embd
+ # // shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@ctypes_function(
"llama_get_embeddings_ith",
@@ -1802,6 +1803,23 @@ def llama_get_embeddings_ith(
...


+ # // Get the embeddings for a sequence id
+ # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+ # // shape: [n_embd] (1-dimensional)
+ # LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+ @ctypes_function(
+     "llama_get_embeddings_seq",
+     [llama_context_p_ctypes, llama_seq_id],
+     ctypes.POINTER(ctypes.c_float),
+ )
+ def llama_get_embeddings_seq(
+     ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /
+ ) -> CtypesArray[ctypes.c_float]:
+     """Get the embeddings for a sequence id
+     Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+     shape: [n_embd] (1-dimensional)"""
+     ...
+
# //
# // Vocab
# //
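The new `llama_get_embeddings_seq` binding returns the pooled embedding for a whole sequence, or NULL when the context uses LLAMA_POOLING_TYPE_NONE. A minimal usage sketch (not part of this patch), again assuming an embeddings-enabled context and a prior successful `llama_decode`:

from typing import Optional
import llama_cpp

def sequence_embedding(ctx: llama_cpp.llama_context_p, model: llama_cpp.llama_model_p, seq_id: int) -> Optional[list[float]]:
    n_embd = llama_cpp.llama_n_embd(model)
    ptr = llama_cpp.llama_get_embeddings_seq(ctx, seq_id)
    if not ptr:          # NULL pointer: pooling_type is LLAMA_POOLING_TYPE_NONE
        return None
    return ptr[:n_embd]  # copy the pooled n_embd floats out of the C buffer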