@@ -568,13 +568,33 @@ def llama_model_n_embd(model: llama_model_p) -> int:
 
 
 # // Get a string describing the model type
-# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size);
-def llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
-    return _lib.llama_model_type(model, buf, buf_size)
+# LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
+def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
+    return _lib.llama_model_desc(model, buf, buf_size)
 
 
-_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t]
-_lib.llama_model_type.restype = c_int
+_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t]
+_lib.llama_model_desc.restype = c_int
+
+
+# // Returns the total size of all the tensors in the model in bytes
+# LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
+def llama_model_size(model: llama_model_p) -> int:
+    return _lib.llama_model_size(model)
+
+
+_lib.llama_model_size.argtypes = [llama_model_p]
+_lib.llama_model_size.restype = ctypes.c_uint64
+
+
+# // Returns the total number of parameters in the model
+# LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
+def llama_model_n_params(model: llama_model_p) -> int:
+    return _lib.llama_model_n_params(model)
+
+
+_lib.llama_model_n_params.argtypes = [llama_model_p]
+_lib.llama_model_n_params.restype = ctypes.c_uint64
 
 
 # // Returns 0 on success
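The renamed `llama_model_desc` fills a caller-owned character buffer, and the new `llama_model_size` / `llama_model_n_params` bindings return unsigned 64-bit totals as plain Python ints. A minimal usage sketch follows; the model path is hypothetical, and it assumes the pre-existing `llama_context_default_params`, `llama_load_model_from_file`, and `llama_free_model` bindings in this module:

```python
import ctypes

import llama_cpp

# Hypothetical GGUF path; substitute a real model file.
params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7B/model.gguf", params)

# llama_model_desc writes a human-readable description into a caller-supplied buffer.
buf = ctypes.create_string_buffer(128)
llama_cpp.llama_model_desc(model, buf, ctypes.sizeof(buf))
print("desc:", buf.value.decode("utf-8"))

# Total tensor size in bytes and total parameter count (uint64 under the hood).
print("size (bytes):", llama_cpp.llama_model_size(model))
print("n_params:", llama_cpp.llama_model_n_params(model))

llama_cpp.llama_free_model(model)
```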
@@ -1029,6 +1049,74 @@ def llama_grammar_free(grammar: llama_grammar_p):
 _lib.llama_grammar_free.argtypes = [llama_grammar_p]
 _lib.llama_grammar_free.restype = None
 
+# //
+# // Beam search
+# //
+
+
+# struct llama_beam_view {
+#     const llama_token * tokens;
+#     size_t n_tokens;
+#     float p;   // Cumulative beam probability (renormalized relative to all beams)
+#     bool eob;  // Callback should set this to true when a beam is at end-of-beam.
+# };
+class llama_beam_view(ctypes.Structure):
+    _fields_ = [
+        ("tokens", llama_token_p),
+        ("n_tokens", c_size_t),
+        ("p", c_float),
+        ("eob", c_bool),
+    ]
+
+
+# // Passed to beam_search_callback function.
+# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+# // These pointers are valid only during the synchronous callback, so should not be saved.
+# struct llama_beams_state {
+#     struct llama_beam_view * beam_views;
+#     size_t n_beams;               // Number of elements in beam_views[].
+#     size_t common_prefix_length;  // Current max length of prefix tokens shared by all beams.
+#     bool last_call;               // True iff this is the last callback invocation.
+# };
+class llama_beams_state(ctypes.Structure):
+    _fields_ = [
+        ("beam_views", POINTER(llama_beam_view)),
+        ("n_beams", c_size_t),
+        ("common_prefix_length", c_size_t),
+        ("last_call", c_bool),
+    ]
+
+
+# // Type of pointer to the beam_search_callback function.
+# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
+llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
+
+
+# /// @details Deterministically returns entire sentence constructed by a beam search.
+# /// @param ctx Pointer to the llama_context.
+# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+# /// @param callback_data A pointer that is simply passed back to callback.
+# /// @param n_beams Number of beams to use.
+# /// @param n_past Number of tokens already evaluated.
+# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+# /// @param n_threads Number of threads as passed to llama_eval().
+# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+def llama_beam_search(
+    ctx: llama_context_p,
+    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
+    callback_data: c_void_p,
+    n_beams: c_size_t,
+    n_past: c_int,
+    n_predict: c_int,
+    n_threads: c_int,
+):
+    return _lib.llama_beam_search(
+        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
+    )
+
 
 # //
 # // Sampling functions
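The beam-search callback is synchronous: `llama_beam_search` drives the decoding loop and invokes the callback once per iteration with a `llama_beams_state` whose pointers are valid only for the duration of that call. A rough usage sketch follows; the model path, beam count, and other parameter values are illustrative, and in real use the prompt would first be tokenized and evaluated (e.g. via `llama_eval`) so that `n_past` reflects the tokens already processed:

```python
import llama_cpp

# Hypothetical GGUF path; substitute a real model file.
params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7B/model.gguf", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)


@llama_cpp.llama_beam_search_callback_fn_t
def beam_search_callback(callback_data, beams_state):
    # Pointers inside beams_state are valid only during this call; copy what you need.
    for i in range(beams_state.n_beams):
        view = beams_state.beam_views[i]
        tokens = [view.tokens[j] for j in range(view.n_tokens)]
        print(f"beam {i}: p={view.p:.4f} eob={view.eob} tokens={tokens}")
    # A real callback would set view.eob = True once a beam reaches an end condition.


llama_cpp.llama_beam_search(
    ctx,
    beam_search_callback,
    None,  # callback_data: nothing extra to pass through here
    4,     # n_beams
    0,     # n_past: set to the number of prompt tokens already evaluated
    32,    # n_predict
    4,     # n_threads
)
```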