Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 03e2947

Browse files
committed
Fix unnecessary memory allocation while sampling
1 parent fafe471 commit 03e2947

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

llama_cpp/llama.py

+33-21
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,28 @@ def __init__(
176176

177177
if self.verbose:
178178
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
179+
180+
181+
n_vocab = self.n_vocab()
182+
n_ctx = self.n_ctx()
183+
data = (llama_cpp.llama_token_data * n_vocab)(
184+
*[
185+
llama_cpp.llama_token_data(
186+
id=llama_cpp.llama_token(i),
187+
logit=llama_cpp.c_float(0.0),
188+
p=llama_cpp.c_float(0.0),
189+
)
190+
for i in range(n_vocab)
191+
]
192+
)
193+
size = llama_cpp.c_size_t(n_vocab)
194+
sorted = False
195+
candidates = llama_cpp.llama_token_data_array(
196+
data=data,
197+
size=size,
198+
sorted=sorted,
199+
)
200+
self._candidates = candidates
179201

180202
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
181203
"""Tokenize a string.
@@ -296,33 +318,23 @@ def _sample(
296318
):
297319
assert self.ctx is not None
298320
assert len(self.eval_logits) > 0
299-
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
300-
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
321+
n_vocab = self.n_vocab()
322+
n_ctx = self.n_ctx()
301323
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
302324
last_n_tokens_size = (
303325
llama_cpp.c_int(n_ctx)
304326
if last_n_tokens_size.value < 0
305327
else last_n_tokens_size
306328
)
307329
logits = self.eval_logits[-1]
308-
nl_logit = logits[int(Llama.token_nl())]
309-
data = (llama_cpp.llama_token_data * n_vocab)(
310-
*[
311-
llama_cpp.llama_token_data(
312-
id=llama_cpp.llama_token(i),
313-
logit=logits[i],
314-
p=llama_cpp.c_float(0.0),
315-
)
316-
for i in range(n_vocab)
317-
]
318-
)
319-
size = llama_cpp.c_size_t(n_vocab)
320-
sorted = False
321-
candidates = llama_cpp.llama_token_data_array(
322-
data=data,
323-
size=size,
324-
sorted=sorted,
325-
)
330+
nl_logit = logits[Llama.token_nl()]
331+
candidates = self._candidates
332+
for i, logit in enumerate(logits):
333+
candidates.data[i].id = llama_cpp.llama_token(i)
334+
candidates.data[i].logit = llama_cpp.c_float(logit)
335+
candidates.data[i].p = llama_cpp.c_float(0.0)
336+
candidates.sorted = llama_cpp.c_bool(False)
337+
candidates.size = llama_cpp.c_size_t(n_vocab)
326338
llama_cpp.llama_sample_repetition_penalty(
327339
ctx=self.ctx,
328340
last_tokens_data=last_n_tokens_data,
@@ -339,7 +351,7 @@ def _sample(
339351
alpha_presence=presence_penalty,
340352
)
341353
if not penalize_nl:
342-
candidates.data[int(Llama.token_nl())].logit = nl_logit
354+
candidates.data[Llama.token_nl()].logit = nl_logit
343355
if temp.value == 0.0:
344356
return llama_cpp.llama_sample_token_greedy(
345357
ctx=self.ctx,

0 commit comments

Comments (0)