@@ -229,8 +229,8 @@ def __init__(
229
229
n_batch : int = 512 ,
230
230
n_threads : Optional [int ] = None ,
231
231
n_threads_batch : Optional [int ] = None ,
232
- rope_freq_base : float = 10000 .0 ,
233
- rope_freq_scale : float = 1 .0 ,
232
+ rope_freq_base : float = 0 .0 ,
233
+ rope_freq_scale : float = 0 .0 ,
234
234
mul_mat_q : bool = True ,
235
235
f16_kv : bool = True ,
236
236
logits_all : bool = False ,
@@ -282,7 +282,6 @@ def __init__(
282
282
Returns:
283
283
A Llama instance.
284
284
"""
285
-
286
285
self .verbose = verbose
287
286
288
287
self .numa = numa
@@ -320,16 +319,19 @@ def __init__(
320
319
self .n_threads_batch = n_threads_batch or max (
321
320
multiprocessing .cpu_count () // 2 , 1
322
321
)
323
-
324
322
# Context Params
325
323
self .context_params = llama_cpp .llama_context_default_params ()
326
324
self .context_params .seed = seed
327
325
self .context_params .n_ctx = n_ctx
328
326
self .context_params .n_batch = self .n_batch
329
327
self .context_params .n_threads = self .n_threads
330
328
self .context_params .n_threads_batch = self .n_threads_batch
331
- self .context_params .rope_freq_base = rope_freq_base
332
- self .context_params .rope_freq_scale = rope_freq_scale
329
+ self .context_params .rope_freq_base = (
330
+ rope_freq_base if rope_freq_base != 0.0 else 0
331
+ )
332
+ self .context_params .rope_freq_scale = (
333
+ rope_freq_scale if rope_freq_scale != 0.0 else 0
334
+ )
333
335
self .context_params .mul_mat_q = mul_mat_q
334
336
self .context_params .f16_kv = f16_kv
335
337
self .context_params .logits_all = logits_all
@@ -338,7 +340,6 @@ def __init__(
338
340
# Sampling Params
339
341
self .last_n_tokens_size = last_n_tokens_size
340
342
341
-
342
343
self .cache : Optional [BaseLlamaCache ] = None
343
344
344
345
self .lora_base = lora_base
0 commit comments