@@ -176,6 +176,28 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -296,33 +318,23 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[Llama.token_nl()]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
@@ -339,7 +351,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[Llama.token_nl()].logit = nl_logit
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
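
In short, the change preallocates the `llama_token_data` candidates array once in `__init__` and stores it as `self._candidates`, so `_sample` only rewrites the struct fields in place instead of constructing `n_vocab` ctypes objects on every sampling call. Below is a minimal standalone sketch of that pattern; `TokenData`, `TokenDataArray`, `N_VOCAB`, and `refill` are illustrative stand-ins that merely mirror `llama_cpp.llama_token_data` / `llama_token_data_array`, not names from the library.

```python
# Sketch of the preallocate-once / refill-per-call pattern from this commit.
import ctypes


class TokenData(ctypes.Structure):
    # Stand-in mirroring llama_token_data: token id, raw logit, probability.
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]


class TokenDataArray(ctypes.Structure):
    # Stand-in mirroring llama_token_data_array.
    _fields_ = [
        ("data", ctypes.POINTER(TokenData)),
        ("size", ctypes.c_size_t),
        ("sorted", ctypes.c_bool),
    ]


N_VOCAB = 32000  # assumed vocabulary size for this sketch

# Allocated once (the __init__ side of the diff). Keep a reference to the
# backing array: assigning it to a POINTER field does not transfer ownership.
_buffer = (TokenData * N_VOCAB)()
candidates = TokenDataArray(data=_buffer, size=N_VOCAB, sorted=False)


def refill(logits: list) -> None:
    # The _sample side of the diff: plain in-place field writes each call,
    # no per-token ctypes object construction.
    for i, logit in enumerate(logits):
        candidates.data[i].id = i
        candidates.data[i].logit = logit
        candidates.data[i].p = 0.0
    candidates.sorted = False
    candidates.size = len(logits)


refill([0.0] * N_VOCAB)
```

The reference held in `_buffer` matters: the struct's `data` pointer does not keep the ctypes array alive, so the array must outlive every struct that points into it, which is exactly what storing `self._candidates` on the instance guarantees in the commit.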