@@ -255,28 +255,28 @@ def __init__(
             for i, (k, v) in enumerate(kv_overrides.items()):
                 self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, bool):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                     self._kv_overrides_array[i].value.val_bool = v
                 elif isinstance(v, int):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                     self._kv_overrides_array[i].value.val_i64 = v
                 elif isinstance(v, float):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                     self._kv_overrides_array[i].value.val_f64 = v
                 elif isinstance(v, str):  # type: ignore
                     v_bytes = v.encode("utf-8")
                     if len(v_bytes) > 128:  # TODO: Make this a constant
                         raise ValueError(f"Value for {k} is too long: {v}")
                     v_bytes = v_bytes.ljust(128, b"\0")
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                     # copy min(v_bytes, 128) to str_value
                     address = typing.cast(
                         int,
@@ -292,9 +292,9 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[-1].key = (
-                b"\0"  # ensure sentinel element is zeroed
-            )
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
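
The two hunks above only re-wrap the code that copies the public `kv_overrides` constructor argument into the C `llama_model_kv_override` array (bool, int, float, and 128-byte-max UTF-8 string values, plus a zeroed sentinel entry). For orientation, a minimal usage sketch; the model path and metadata keys are placeholders, not values from this PR:

```python
from llama_cpp import Llama

# Placeholder path and keys; each entry maps to one _kv_overrides_array slot.
llm = Llama(
    model_path="./models/model.gguf",
    kv_overrides={
        "tokenizer.ggml.add_bos_token": False,  # bool  -> LLAMA_KV_OVERRIDE_TYPE_BOOL
        "some.int.key": 42,                     # int   -> LLAMA_KV_OVERRIDE_TYPE_INT
        "some.float.key": 0.5,                  # float -> LLAMA_KV_OVERRIDE_TYPE_FLOAT
        "some.str.key": "value",                # str   -> LLAMA_KV_OVERRIDE_TYPE_STR (<= 128 UTF-8 bytes)
    },
)
```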
@@ -431,9 +431,9 @@ def free_lora_adapter():
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
-        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = (
-            {}
-        )
+        self._chat_handlers: Dict[
+            str, llama_chat_format.LlamaChatCompletionHandler
+        ] = {}
 
         self.draft_model = draft_model
 
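
The `_chat_handlers` dict reformatted above is an internal cache keyed by chat format name. From the public API you normally just pass `chat_format` (or a custom `chat_handler`) to the constructor; a small sketch with a placeholder model path:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", chat_format="chatml")

# create_chat_completion resolves the handler for the configured chat format.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=8,
)
print(response["choices"][0]["message"]["content"])
```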
@@ -580,7 +580,10 @@ def tokenize(
         return self.tokenizer_.tokenize(text, add_bos, special)
 
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize a list of tokens.
 
@@ -592,7 +595,9 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
+        return self.tokenizer_.detokenize(
+            tokens, prev_tokens=prev_tokens, special=special
+        )
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -681,12 +686,16 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 recarray = np.recarray(
                     shape=(size,),
                     dtype=np.dtype(
-                        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+                        [("id", np.intc), ("logit", np.single), ("p", np.single)],
+                        align=True,
+                    ),
+                    buf=(llama_cpp.llama_token_data * size).from_address(
+                        data_soa_address
                     ),
-                    buf=(llama_cpp.llama_token_data * size).from_address(data_soa_address),
                 )
                 for logit_processor in logits_processor:
                     recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
+
             sampler.add_custom(apply_func)
 
         sampler.add_penalties(
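
The `apply_func` being re-wrapped views the sampler's token-data array as a NumPy recarray and runs every user-supplied logits processor over it. A logits processor is just a callable `(input_ids, scores) -> scores`; a minimal sketch (the token id is a placeholder, and I am assuming the `LogitsProcessorList` helper exported by `llama_cpp`):

```python
import numpy as np
import numpy.typing as npt

from llama_cpp import LogitsProcessorList


def ban_token(token_id: int):
    """Build a processor that forbids one (placeholder) token id."""

    def processor(
        input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        # Same contract as the recarray.logit array fed in by apply_func above.
        scores = scores.copy()
        scores[token_id] = -np.inf
        return scores

    return processor


# Passed as the logits_processor argument of generate()/create_completion().
processors = LogitsProcessorList([ban_token(token_id=123)])
```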
@@ -698,7 +707,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
             penalize_nl=penalize_nl,
-            ignore_eos=False
+            ignore_eos=False,
         )
 
         if grammar is not None:
@@ -841,22 +850,22 @@ def generate(
         # Reset mirostat sampling
         self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
         self._sampler = self._init_sampler(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                penalize_nl=penalize_nl,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                seed=seed,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            seed=seed,
         )
 
         # Check for kv cache prefix match
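
The re-indented block simply forwards the sampling parameters of `generate()` to `_init_sampler`. A sketch of driving the low-level `generate()` generator directly with some of those parameters (model path is a placeholder):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", verbose=False)

prompt_tokens = llm.tokenize(b"Q: What is 2 + 2? A:")
out = []
for token in llm.generate(
    prompt_tokens,
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
):
    if token == llm.token_eos():
        break
    out.append(token)
    if len(out) >= 16:  # cap the demo output
        break

print(llm.detokenize(out).decode("utf-8", errors="ignore"))
```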
@@ -872,8 +881,11 @@ def generate(
                 tokens = tokens[longest_prefix:]
                 self.n_tokens = longest_prefix
                 if self.verbose:
-                    print(f"Llama.generate: {longest_prefix} prefix-match hit, "
-                          f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
+                    print(
+                        f"Llama.generate: {longest_prefix} prefix-match hit, "
+                        f"remaining {len(tokens)} prompt tokens to eval",
+                        file=sys.stderr,
+                    )
 
         # Reset the model state
         if reset:
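
The reformatted verbose message reports how many prompt tokens were already present in the KV cache from the previous call; only the remaining tokens get re-evaluated. A sketch of the situation it describes, assuming `llm` was created with `verbose=True`:

```python
# Two prompts that share a long prefix: the second call should log something
# like "Llama.generate: N prefix-match hit, remaining M prompt tokens to eval"
# and only evaluate the tokens that come after the shared prefix.
shared = "You are a helpful assistant.\n\n"
llm(shared + "Q: Capital of France? A:", max_tokens=16)
llm(shared + "Q: Capital of Spain? A:", max_tokens=16)
```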
@@ -1032,7 +1044,9 @@ def decode_batch(seq_sizes: List[int]):
                         for j in range(size)
                     ]
                     if normalize:
-                        embedding = [internals.normalize_embedding(e) for e in embedding]
+                        embedding = [
+                            internals.normalize_embedding(e) for e in embedding
+                        ]
                     data.append(embedding)
                     pos += size
             else:
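
The last hunk re-wraps the per-sequence normalization in `embed()`. As I understand `internals.normalize_embedding`, it rescales each embedding to unit L2 norm; a hedged NumPy equivalent for reference, not the library's own implementation:

```python
import numpy as np


def normalize_embedding_np(vec: list[float]) -> list[float]:
    # Assumed behaviour of internals.normalize_embedding: unit L2 norm,
    # with all-zero vectors returned unchanged.
    arr = np.asarray(vec, dtype=np.float32)
    norm = float(np.linalg.norm(arr))
    return arr.tolist() if norm == 0.0 else (arr / norm).tolist()


# High-level usage (placeholder path; requires embedding=True):
# emb = Llama(model_path="./models/model.gguf", embedding=True).embed("hello")
```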