@@ -230,8 +230,8 @@ def __init__(
         n_batch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
-        rope_freq_base: float = 10000.0,
-        rope_freq_scale: float = 1.0,
+        rope_freq_base: float = 0.0,
+        rope_freq_scale: float = 0.0,
         mul_mat_q: bool = True,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -286,7 +286,6 @@ def __init__(
         Returns:
             A Llama instance.
         """
-
         self.verbose = verbose

         self.numa = numa
@@ -324,16 +323,19 @@ def __init__(
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
-        self.context_params.rope_freq_base = rope_freq_base
-        self.context_params.rope_freq_scale = rope_freq_scale
+        self.context_params.rope_freq_base = (
+            rope_freq_base if rope_freq_base != 0.0 else 0
+        )
+        self.context_params.rope_freq_scale = (
+            rope_freq_scale if rope_freq_scale != 0.0 else 0
+        )
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.f16_kv = f16_kv
         self.context_params.logits_all = logits_all
@@ -342,7 +344,6 @@ def __init__(
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

-
         self.cache: Optional[BaseLlamaCache] = None

         self.lora_base = lora_base
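
The effect of the change: the RoPE defaults move from hard-coded values (10000.0 and 1.0) to a 0.0 sentinel, which llama.cpp treats as "use the value stored in the model file". A minimal usage sketch, assuming a local GGUF model (the model path below is hypothetical):

from llama_cpp import Llama

# With the new 0.0 defaults, omitting rope_freq_base/rope_freq_scale lets
# llama.cpp fall back to the RoPE values in the model's metadata instead
# of forcing the old hard-coded defaults (10000.0 and 1.0).
llm = Llama(model_path="./models/7B/model.gguf")  # hypothetical path

# Passing explicit non-zero values still overrides the model's metadata:
llm_scaled = Llama(
    model_path="./models/7B/model.gguf",  # hypothetical path
    rope_freq_base=10000.0,
    rope_freq_scale=0.5,  # e.g. linear RoPE scaling for a longer context
)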