Commit 0b64530

Merge branch 'main' into patch-3
2 parents d8dbdfd + 9b64bb5 commit 0b64530

File tree: 8 files changed, +300 −156 lines changed

llama_cpp/_internals.py (+50 −16)
@@ -100,7 +100,6 @@ def n_params(self) -> int:
     def get_tensor(self, name: str) -> ctypes.c_void_p:
         return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
 
-
     # Vocab
 
     def token_get_text(self, token: int) -> str:
@@ -460,9 +459,7 @@ def __init__(
         self.verbose = verbose
         self._exit_stack = ExitStack()
 
-        batch = llama_cpp.llama_batch_init(
-            self._n_tokens, self.embd, self.n_seq_max
-        )
+        batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)
 
         if batch is None:
             raise ValueError("Failed to create llama_batch")
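
Note: llama_batch_init hands back a C-allocated batch, and the wrapper treats a NULL result as an allocation failure. A minimal standalone sketch of the same guard against the raw binding (sizes illustrative; the real class frees the batch through its ExitStack, reduced here to try/finally):

    import llama_cpp

    batch = llama_cpp.llama_batch_init(512, 0, 1)  # n_tokens, embd, n_seq_max
    if batch is None:
        raise ValueError("Failed to create llama_batch")
    try:
        pass  # fill batch.token / batch.pos / batch.logits, then decode
    finally:
        llama_cpp.llama_batch_free(batch)  # release the C-side buffers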
@@ -541,6 +538,7 @@ def copy_logits(self, logits: npt.NDArray[np.single]):
 
 # Embedding functions
 
+
 def normalize_embedding(embedding):
     norm = float(np.linalg.norm(embedding))
     if norm == 0.0:
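
The hunk cuts off mid-function; for reference, a sketch of the complete L2 normalization this helper performs (the tail past the lines shown is inferred from the zero-norm guard, not part of this diff):

    import numpy as np

    def normalize_embedding(embedding):
        norm = float(np.linalg.norm(embedding))
        if norm == 0.0:
            return embedding  # all-zero vector: avoid dividing by zero
        return [v / norm for v in embedding]  # rescale to unit L2 norm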
@@ -713,11 +711,17 @@ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
 import ctypes
 import llama_cpp
 
+
 class CustomSampler:
-    def __init__(self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]):
+    def __init__(
+        self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]
+    ):
         self.apply_func = apply_func
 
-        def apply_wrapper(sampler: llama_cpp.llama_sampler_p, cur_p: llama_cpp.llama_token_data_array_p):
+        def apply_wrapper(
+            sampler: llama_cpp.llama_sampler_p,
+            cur_p: llama_cpp.llama_token_data_array_p,
+        ):
             self.apply_func(cur_p)
 
         def free_wrapper(sampler: llama_cpp.llama_sampler_p):
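
CustomSampler adapts a plain Python callable into a llama.cpp sampler stage: apply_wrapper receives a ctypes pointer to the candidate array and forwards it to apply_func. A hypothetical callback, assuming cur_p arrives as a pointer to a llama_token_data_array (a size field plus a data array of id/logit/p records), as the wrapper's annotation suggests:

    def suppress_token(cur_p):
        arr = cur_p.contents  # dereference llama_token_data_array_p
        for i in range(arr.size):
            if arr.data[i].id == 0:  # hypothetical rule: never sample token 0
                arr.data[i].logit = float("-inf")

    custom = CustomSampler(suppress_token)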
@@ -740,6 +744,7 @@ def free_wrapper(sampler: llama_cpp.llama_sampler_p):
     def get_sampler(self) -> llama_cpp.llama_sampler_p:
         return ctypes.pointer(self.sampler)
 
+
 class LlamaSampler:
     def __init__(self):
         params = llama_cpp.llama_sampler_chain_params()
@@ -788,33 +793,62 @@ def add_temp_ext(self, t: float, delta: float, exponent: float):
         self._add_sampler(sampler)
 
     def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
-        sampler = llama_cpp.llama_sampler_init_mirostat(
-            n_vocab, seed, tau, eta, m
-        )
+        sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
         self._add_sampler(sampler)
 
     def add_mirostat_v2(self, seed: int, tau: float, eta: float):
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
-        sampler = llama_cpp.llama_sampler_init_grammar(model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8"))
+        sampler = llama_cpp.llama_sampler_init_grammar(
+            model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
+        )
         self._add_sampler(sampler)
 
-    def add_penalties(self, n_vocab: int, special_eos_id: int, linefeed_id: int, penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, penalize_nl: bool, ignore_eos: bool):
-        sampler = llama_cpp.llama_sampler_init_penalties(n_vocab, special_eos_id, linefeed_id, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos)
+    def add_penalties(
+        self,
+        n_vocab: int,
+        special_eos_id: int,
+        linefeed_id: int,
+        penalty_last_n: int,
+        penalty_repeat: float,
+        penalty_freq: float,
+        penalty_present: float,
+        penalize_nl: bool,
+        ignore_eos: bool,
+    ):
+        sampler = llama_cpp.llama_sampler_init_penalties(
+            n_vocab,
+            special_eos_id,
+            linefeed_id,
+            penalty_last_n,
+            penalty_repeat,
+            penalty_freq,
+            penalty_present,
+            penalize_nl,
+            ignore_eos,
+        )
         self._add_sampler(sampler)
 
-    def init_logit_bias(self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p):
-        sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, n_logit_bias, logit_bias)
+    def init_logit_bias(
+        self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
+    ):
+        sampler = llama_cpp.llama_sampler_init_logit_bias(
+            n_vocab, n_logit_bias, logit_bias
+        )
         self._add_sampler(sampler)
 
-    def add_custom(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
+    def add_custom(
+        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
+    ):
         custom_sampler = CustomSampler(apply_func)
         sampler = custom_sampler.get_sampler()
         self._add_sampler(sampler)
         # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
-        self.custom_samplers.append((llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler))
+        self.custom_samplers.append(
+            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
+        )
 
     def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
         assert self.sampler is not None
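
Taken together, LlamaSampler is a thin builder over llama.cpp's sampler chain; each add_* call appends one stage. A usage sketch built only from the methods in this hunk (all numeric values illustrative, not library defaults):

    sampler = LlamaSampler()
    sampler.add_penalties(
        n_vocab=32000,      # vocabulary size of the loaded model
        special_eos_id=2,
        linefeed_id=13,
        penalty_last_n=64,  # window of recent tokens to penalize
        penalty_repeat=1.1,
        penalty_freq=0.0,
        penalty_present=0.0,
        penalize_nl=False,
        ignore_eos=False,
    )
    sampler.add_mirostat_v2(seed=1234, tau=5.0, eta=0.1)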

llama_cpp/llama.py (+56 −42)
@@ -255,28 +255,28 @@ def __init__(
             for i, (k, v) in enumerate(kv_overrides.items()):
                 self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, bool):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                     self._kv_overrides_array[i].value.val_bool = v
                 elif isinstance(v, int):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                     self._kv_overrides_array[i].value.val_i64 = v
                 elif isinstance(v, float):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                     self._kv_overrides_array[i].value.val_f64 = v
                 elif isinstance(v, str):  # type: ignore
                     v_bytes = v.encode("utf-8")
                     if len(v_bytes) > 128:  # TODO: Make this a constant
                         raise ValueError(f"Value for {k} is too long: {v}")
                     v_bytes = v_bytes.ljust(128, b"\0")
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                     # copy min(v_bytes, 128) to str_value
                     address = typing.cast(
                         int,
@@ -292,9 +292,9 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[-1].key = (
-                b"\0"  # ensure sentinel element is zeroed
-            )
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
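
The array populated above is fed from the kv_overrides argument of the Llama constructor, and the Python type of each dict value selects the override tag. A sketch of the calling side (path and keys are illustrative, not real GGUF metadata keys):

    from llama_cpp import Llama

    llm = Llama(
        model_path="./model.gguf",       # hypothetical path
        kv_overrides={
            "example.bool.key": True,    # -> LLAMA_KV_OVERRIDE_TYPE_BOOL
            "example.int.key": 4096,     # -> LLAMA_KV_OVERRIDE_TYPE_INT
            "example.float.key": 1.5,    # -> LLAMA_KV_OVERRIDE_TYPE_FLOAT
            "example.str.key": "value",  # -> LLAMA_KV_OVERRIDE_TYPE_STR (<= 128 bytes)
        },
    )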
@@ -431,9 +431,9 @@ def free_lora_adapter():
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
-        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = (
-            {}
-        )
+        self._chat_handlers: Dict[
+            str, llama_chat_format.LlamaChatCompletionHandler
+        ] = {}
 
         self.draft_model = draft_model
 

@@ -580,7 +580,10 @@ def tokenize(
         return self.tokenizer_.tokenize(text, add_bos, special)
 
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize a list of tokens.
@@ -592,7 +595,9 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
+        return self.tokenizer_.detokenize(
+            tokens, prev_tokens=prev_tokens, special=special
+        )
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -681,12 +686,16 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             recarray = np.recarray(
                 shape=(size,),
                 dtype=np.dtype(
-                    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+                    [("id", np.intc), ("logit", np.single), ("p", np.single)],
+                    align=True,
+                ),
+                buf=(llama_cpp.llama_token_data * size).from_address(
+                    data_soa_address
                 ),
-                buf=(llama_cpp.llama_token_data * size).from_address(data_soa_address),
             )
             for logit_processor in logits_processor:
                 recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
+
         sampler.add_custom(apply_func)

@@ -698,7 +707,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
             penalize_nl=penalize_nl,
-            ignore_eos=False
+            ignore_eos=False,
         )
 
         if grammar is not None:
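
apply_func above exposes the candidate array to ordinary NumPy logits processors: each processor is a callable taking (input_ids, logits) and returning the adjusted logits. A hypothetical processor that bans a single token id:

    import numpy as np

    def ban_token(token_id: int):
        def processor(input_ids: np.ndarray, logits: np.ndarray) -> np.ndarray:
            logits[token_id] = -np.inf  # the sampler can no longer pick this token
            return logits
        return processor

    # e.g. passed through generate()/__call__ as logits_processor=[ban_token(123)]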
@@ -841,22 +850,22 @@ def generate(
         # Reset mirostat sampling
         self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
         self._sampler = self._init_sampler(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                penalize_nl=penalize_nl,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                seed=seed,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            seed=seed,
         )
 
         # Check for kv cache prefix match
@@ -872,8 +881,11 @@ def generate(
             tokens = tokens[longest_prefix:]
             self.n_tokens = longest_prefix
             if self.verbose:
-                print(f"Llama.generate: {longest_prefix} prefix-match hit, "
-                      f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
+                print(
+                    f"Llama.generate: {longest_prefix} prefix-match hit, "
+                    f"remaining {len(tokens)} prompt tokens to eval",
+                    file=sys.stderr,
+                )
 
         # Reset the model state
         if reset:
10321044
for j in range(size)
10331045
]
10341046
if normalize:
1035-
embedding = [internals.normalize_embedding(e) for e in embedding]
1047+
embedding = [
1048+
internals.normalize_embedding(e) for e in embedding
1049+
]
10361050
data.append(embedding)
10371051
pos += size
10381052
else:
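
The normalize flag decides whether each pooled embedding is passed through internals.normalize_embedding before being returned. A sketch of the public entry point, assuming a model loaded with embedding=True and that embed exposes this flag (check the current signature before relying on it):

    import numpy as np

    vec = llm.embed("hello world", normalize=True)
    print(np.linalg.norm(vec))  # ~1.0 when normalization is on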

llama_cpp/llama_cache.py (+3 −3)
@@ -52,9 +52,9 @@ class LlamaRAMCache(BaseLlamaCache):
     def __init__(self, capacity_bytes: int = (2 << 30)):
         super().__init__(capacity_bytes)
         self.capacity_bytes = capacity_bytes
-        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = (
-            OrderedDict()
-        )
+        self.cache_state: OrderedDict[
+            Tuple[int, ...], "llama_cpp.llama.LlamaState"
+        ] = OrderedDict()
 
     @property
     def cache_size(self):
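
LlamaRAMCache keeps saved LlamaState objects in an OrderedDict keyed by token prefix, evicting within the capacity_bytes budget. Hooking it up is one call to the set_cache method seen earlier in this diff (path and capacity illustrative):

    from llama_cpp import Llama
    from llama_cpp.llama_cache import LlamaRAMCache

    llm = Llama(model_path="./model.gguf")  # hypothetical path
    llm.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))  # ~2 GiB budget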
