From 813da0175eae5edff8d2cc0887a14ee69185051d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 29 Sep 2023 14:06:24 -0400 Subject: [PATCH 1/5] Fix rope scale with backwards compatibility --- llama_cpp/llama.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 54424cb91..85f722eb7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -282,7 +282,6 @@ def __init__( Returns: A Llama instance. """ - self.verbose = verbose self.numa = numa @@ -320,7 +319,6 @@ def __init__( self.n_threads_batch = n_threads_batch or max( multiprocessing.cpu_count() // 2, 1 ) - # Context Params self.context_params = llama_cpp.llama_context_default_params() self.context_params.seed = seed @@ -328,8 +326,16 @@ def __init__( self.context_params.n_batch = self.n_batch self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch - self.context_params.rope_freq_base = rope_freq_base - self.context_params.rope_freq_scale = rope_freq_scale + self.context_params.rope_freq_base = ( + rope_freq_base + if rope_freq_base == 10000.0 and rope_freq_scale == 0.0 + else 0 + ) + self.context_params.rope_freq_scale = ( + rope_freq_scale + if rope_freq_base == 10000.0 and rope_freq_scale == 0.0 + else 0 + ) self.context_params.mul_mat_q = mul_mat_q self.context_params.f16_kv = f16_kv self.context_params.logits_all = logits_all @@ -338,7 +344,6 @@ def __init__( # Sampling Params self.last_n_tokens_size = last_n_tokens_size - self.cache: Optional[BaseLlamaCache] = None self.lora_base = lora_base From 9eaffdb577a1abb61583c4914e58a268f056f562 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 29 Sep 2023 14:07:47 -0400 Subject: [PATCH 2/5] Fix defaults --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 85f722eb7..a879e60b8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -328,12 +328,12 @@ def __init__( self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_freq_base = ( rope_freq_base - if rope_freq_base == 10000.0 and rope_freq_scale == 0.0 + if rope_freq_base == 10000.0 and rope_freq_scale == 1.0 else 0 ) self.context_params.rope_freq_scale = ( rope_freq_scale - if rope_freq_base == 10000.0 and rope_freq_scale == 0.0 + if rope_freq_base == 10000.0 and rope_freq_scale == 1.0 else 0 ) self.context_params.mul_mat_q = mul_mat_q From 9997432e97bf5f93a4c949a677db33e785896787 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 29 Sep 2023 14:11:25 -0400 Subject: [PATCH 3/5] Fix op --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a879e60b8..0deef7ab4 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -328,12 +328,12 @@ def __init__( self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_freq_base = ( rope_freq_base - if rope_freq_base == 10000.0 and rope_freq_scale == 1.0 + if rope_freq_base != 10000.0 and rope_freq_scale != 1.0 else 0 ) self.context_params.rope_freq_scale = ( rope_freq_scale - if rope_freq_base == 10000.0 and rope_freq_scale == 1.0 + if rope_freq_base != 10000.0 and rope_freq_scale != 1.0 else 0 ) self.context_params.mul_mat_q = mul_mat_q From 0bd22a4bfde87843e009839fcc86169a698248fa Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 29 Sep 2023 14:21:34 -0400 Subject: [PATCH 4/5] Remove backwards compatibility --- llama_cpp/llama.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 0deef7ab4..5c344c61a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -229,8 +229,8 @@ def __init__( n_batch: int = 512, n_threads: Optional[int] = None, n_threads_batch: Optional[int] = None, - rope_freq_base: float = 10000.0, - rope_freq_scale: float = 1.0, + rope_freq_base: float = 0.0, + rope_freq_scale: float = 0.0, mul_mat_q: bool = True, f16_kv: bool = True, logits_all: bool = False, @@ -327,14 +327,10 @@ def __init__( self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_freq_base = ( - rope_freq_base - if rope_freq_base != 10000.0 and rope_freq_scale != 1.0 - else 0 + rope_freq_base if rope_freq_base != 0.0 and rope_freq_scale != 0.0 else 0 ) self.context_params.rope_freq_scale = ( - rope_freq_scale - if rope_freq_base != 10000.0 and rope_freq_scale != 1.0 - else 0 + rope_freq_scale if rope_freq_base != 0.0 and rope_freq_scale != 0.0 else 0 ) self.context_params.mul_mat_q = mul_mat_q self.context_params.f16_kv = f16_kv From 92144a6caeef70a3aa41cfb399d496cf098e3edc Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 29 Sep 2023 14:23:19 -0400 Subject: [PATCH 5/5] Check single val --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5c344c61a..966f79f8e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -327,10 +327,10 @@ def __init__( self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_freq_base = ( - rope_freq_base if rope_freq_base != 0.0 and rope_freq_scale != 0.0 else 0 + rope_freq_base if rope_freq_base != 0.0 else 0 ) self.context_params.rope_freq_scale = ( - rope_freq_scale if rope_freq_base != 0.0 and rope_freq_scale != 0.0 else 0 + rope_freq_scale if rope_freq_scale != 0.0 else 0 ) self.context_params.mul_mat_q = mul_mat_q self.context_params.f16_kv = f16_kv