Commit dbcf64c

CISC and abetlen authored
feat: Support SPM infill (abetlen#1492)
* Support SPM infill
* typo--
* one less layer of parenthesis necessary
* new required internals
* manually add bos/eos if model requires it
* add bos even when unknown

  This is identical behaviour to llama.cpp; I guess any model that doesn't use BOS is recent enough to have the add_bos_token metadata.
* don't add bos/eos on non-infill pre-tokenized prompt
* add tokenizer hack to remove leading space in suffix
* I keep forgetting metadata are strings
* check if bos exists
* add example
* add cls/sep instead of bos/eos for WPM vocab
* simplify
* color-code filtered suffix

---------

Co-authored-by: Andrei Betlen <[email protected]>
1 parent e342161 commit dbcf64c
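
For context, the new spm_infill flag switches the order in which the fill-in-the-middle special tokens are assembled. A minimal sketch of the two layouts as this commit implements them, using placeholder strings instead of the model's real special-token IDs:

# Illustrative only: the real code path uses the vocabulary's special-token IDs
# (token_prefix / token_suffix / token_middle), not these placeholder strings.
prefix = "def add("               # code before the cursor
suffix = "\n    return sum\n\n"   # code after the cursor

# PSM (default): Prefix / Suffix / Middle
psm = ["<PRE>", prefix, "<SUF>", suffix, "<MID>"]

# SPM (spm_infill=True): Suffix / Prefix / Middle, which some models prefer
spm = ["<SUF>", suffix, "<PRE>", prefix, "<MID>"]

In the actual diff below, BOS/EOS (or CLS/SEP for WPM vocabularies) additionally wrap the whole sequence when the model's metadata asks for them.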

File tree

3 files changed: +87 -27 lines changed

3 files changed

+87
-27
lines changed
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import argparse

from llama_cpp import Llama

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
parser.add_argument("-p", "--prompt", type=str, default="def add(")
parser.add_argument("-s", "--suffix", type=str, default="\n    return sum\n\n")
parser.add_argument("-i", "--spm-infill", action='store_true')
args = parser.parse_args()

llm = Llama(model_path=args.model, n_gpu_layers=-1, spm_infill=args.spm_infill)

output = llm.create_completion(
    temperature = 0.0,
    repeat_penalty = 1.0,
    prompt = args.prompt,
    suffix = args.suffix,
)

# Models sometimes repeat suffix in response, attempt to filter that
response = output["choices"][0]["text"]
response_stripped = response.rstrip()
unwanted_response_suffix = args.suffix.rstrip()
unwanted_response_length = len(unwanted_response_suffix)

filtered = False
if unwanted_response_suffix and response_stripped[-unwanted_response_length:] == unwanted_response_suffix:
    response = response_stripped[:-unwanted_response_length]
    filtered = True

print(f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m")
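
As a rough usage sketch of the same flow without argparse (the model path here is a placeholder, not part of the commit; any fill-in-the-middle capable GGUF model should work):

from llama_cpp import Llama

# Placeholder path: substitute a local FIM-capable GGUF model.
llm = Llama(model_path="./model.gguf", n_gpu_layers=-1, spm_infill=True)

output = llm.create_completion(
    prompt="def add(",              # text before the cursor
    suffix="\n    return sum\n\n",  # text after the cursor
    temperature=0.0,
    repeat_penalty=1.0,
)
print(output["choices"][0]["text"])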

llama_cpp/_internals.py

Lines changed: 8 additions & 0 deletions
@@ -170,6 +170,14 @@ def token_eot(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_eot(self.model)

+    def add_bos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_bos_token(self.model)
+
+    def add_eos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_eos_token(self.model)
+
     # Tokenization

     def tokenize(self, text: bytes, add_bos: bool, special: bool):
llama_cpp/llama.py

Lines changed: 46 additions & 27 deletions
@@ -115,6 +115,7 @@ def __init__(
         type_k: Optional[int] = None,
         type_v: Optional[int] = None,
         # Misc
+        spm_infill: bool = False,
         verbose: bool = True,
         # Extra Params
         **kwargs, # type: ignore
@@ -185,6 +186,7 @@ def __init__(
             verbose: Print verbose output to stderr.
             type_k: KV cache data type for K (default: f16)
             type_v: KV cache data type for V (default: f16)
+            spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.

         Raises:
             ValueError: If the model path does not exist.
@@ -343,6 +345,8 @@ def __init__(
         self.lora_scale = lora_scale
         self.lora_path = lora_path

+        self.spm_infill = spm_infill
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

@@ -972,14 +976,33 @@ def _create_completion(

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        bos_token_id: int = self.token_bos()
+        cls_token_id: int = self._model.token_cls()
+        sep_token_id: int = self._model.token_sep()
         prefix_token_id: int = self._model.token_prefix()
         middle_token_id: int = self._model.token_middle()
         suffix_token_id: int = self._model.token_suffix()
+        add_space_prefix: bool = self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
+        bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
+        eos_tokens: List[int] = [sep_token_id if sep_token_id != -1 else self.token_eos()]
+
+        if (isinstance(prompt, list) and suffix is None) or self._model.add_bos_token() == 0 or bos_tokens[:1] == [-1]:
+            bos_tokens = []
+
+        if (isinstance(prompt, list) and suffix is None) or (self._model.add_eos_token() != 1 and sep_token_id == -1):
+            eos_tokens = []
+
+        suffix_space_prefix: int = 0
+        # Tokenizer hack to remove leading space
+        if add_space_prefix and suffix_token_id >= 0 and suffix:
+            suffix = "☺" + suffix
+            suffix_space_prefix = 2
+
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
-        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = (
+        prefix_tokens: List[int] = (
             (
                 [prefix_token_id]
                 if prefix_token_id >= 0 and suffix is not None
@@ -988,38 +1011,33 @@ def _create_completion(
             +
             (
                 (
-                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    self.tokenize(prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0 or suffix is None))
                     if prompt != ""
-                    else (
-                        []
-                        if prefix_token_id >= 0 and suffix is not None
-                        else [self.token_bos()]
-                    )
+                    else []
                 )
                 if isinstance(prompt, str)
                 else prompt
             )
-            +
+        )
+        suffix_tokens: List[int] = (
             (
+                [suffix_token_id]
+                +
                 (
-                    [suffix_token_id]
-                    +
-                    (
-                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
-                        if suffix
-                        else []
-                    )
+                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[suffix_space_prefix:]
+                    if suffix
+                    else []
                 )
-                if suffix_token_id >= 0 and suffix is not None
-                else []
-            )
-            +
-            (
-                [middle_token_id]
-                if middle_token_id >= 0 and suffix is not None
-                else []
             )
+            if suffix_token_id >= 0 and suffix is not None
+            else []
+        )
+        middle_tokens: List[int] = (
+            [middle_token_id]
+            if middle_token_id >= 0 and suffix is not None
+            else []
         )
+        prompt_tokens: List[int] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens) if self.spm_infill else (prefix_tokens + suffix_tokens + middle_tokens)) + eos_tokens
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
@@ -1176,7 +1194,7 @@ def logit_bias_processor(
                 # not sure how to handle this branch when dealing
                 # with CJK output, so keep it unchanged
                 for token in remaining_tokens:
-                    if token == self.token_bos():
+                    if token == bos_token_id:
                         continue
                     token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                     # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ def logit_bias_processor(

                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                             continue
                         token_str = self.detokenize([token]).decode(
                             "utf-8", errors="ignore"
@@ -1431,7 +1449,7 @@ def logit_bias_processor(
             for idx, (token, token_str, logprobs_token) in enumerate(
                 zip(all_tokens, all_token_strs, all_logprobs)
             ):
-                if token == self.token_bos():
+                if token == bos_token_id:
                     continue
                 text_offsets.append(
                     text_offset
@@ -1858,6 +1876,7 @@ def __getstate__(self):
             type_k=self.context_params.type_k,
             type_v=self.context_params.type_v,
             # Misc
+            spm_infill=self.spm_infill,
             verbose=self.verbose,
         )
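
One subtle piece of the diff above is the leading-space workaround: when tokenizer.ggml.add_space_prefix is set, tokenizing the suffix on its own would pick up an injected leading space, so the code prepends a throwaway "☺" character and then drops the first two tokens of the result. A hedged sketch of the same idea, where tokenize stands in for any str -> List[int] tokenizer and the two-token count is an assumption inherited from the commit:

from typing import Callable, List

def tokenize_suffix(
    tokenize: Callable[[str], List[int]],
    suffix: str,
    add_space_prefix: bool,
) -> List[int]:
    # Sketch of the hack in the diff above, not the library's public API.
    if not (add_space_prefix and suffix):
        return tokenize(suffix)
    # Prepend a sentinel so the injected leading space attaches to it,
    # then drop the sentinel's tokens (assumed to be two: space + "☺").
    return tokenize("☺" + suffix)[2:]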

0 commit comments
