@@ -115,6 +115,7 @@ def __init__(
         type_k: Optional[int] = None,
         type_v: Optional[int] = None,
         # Misc
+        spm_infill: bool = False,
         verbose: bool = True,
         # Extra Params
         **kwargs,  # type: ignore
@@ -185,6 +186,7 @@ def __init__(
             verbose: Print verbose output to stderr.
             type_k: KV cache data type for K (default: f16)
             type_v: KV cache data type for V (default: f16)
+            spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.

         Raises:
             ValueError: If the model path does not exist.
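For context, a minimal usage sketch of the new flag (the model path is a placeholder; any FIM-capable GGUF model applies):

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", spm_infill=True)

    # Passing a suffix triggers the infill path; with spm_infill=True the
    # prompt is assembled suffix-first (S/P/M) rather than the default
    # prefix-first (P/S/M) ordering.
    out = llm.create_completion(prompt="def add(a, b):\n", suffix="\n    return result\n")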
@@ -343,6 +345,8 @@ def __init__(
         self.lora_scale = lora_scale
         self.lora_path = lora_path

+        self.spm_infill = spm_infill
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

@@ -972,14 +976,33 @@ def _create_completion(

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        bos_token_id: int = self.token_bos()
+        cls_token_id: int = self._model.token_cls()
+        sep_token_id: int = self._model.token_sep()
         prefix_token_id: int = self._model.token_prefix()
         middle_token_id: int = self._model.token_middle()
         suffix_token_id: int = self._model.token_suffix()
+        add_space_prefix: bool = self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
+        bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
+        eos_tokens: List[int] = [sep_token_id if sep_token_id != -1 else self.token_eos()]
+
+        if (isinstance(prompt, list) and suffix is None) or self._model.add_bos_token() == 0 or bos_tokens[:1] == [-1]:
+            bos_tokens = []
+
+        if (isinstance(prompt, list) and suffix is None) or (self._model.add_eos_token() != 1 and sep_token_id == -1):
+            eos_tokens = []
+
+        suffix_space_prefix: int = 0
+        # Tokenizer hack to remove leading space
+        if add_space_prefix and suffix_token_id >= 0 and suffix:
+            suffix = "☺" + suffix
+            suffix_space_prefix = 2
+
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
-        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = (
+        prefix_tokens: List[int] = (
             (
                 [prefix_token_id]
                 if prefix_token_id >= 0 and suffix is not None
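The "☺" sentinel above works around SentencePiece-style tokenizers that prepend a space to every input (the tokenizer.ggml.add_space_prefix metadata key): without it, the suffix segment would gain a stray leading space. The patch assumes the injected space plus the sentinel occupy exactly the first two token positions, which the suffix_space_prefix slice then drops. A standalone sketch of the trick, with tokenize standing in for self.tokenize:

    def tokenize_without_leading_space(tokenize, text: str) -> list:
        # tokenize is assumed to prepend a space when add_space_prefix is set,
        # so "bar" tokenizes as "▁bar". Prepending a sentinel makes the
        # injected space attach to the sentinel instead of the real text;
        # slicing off the first two tokens (space + sentinel, per the patch)
        # leaves the suffix tokens unpolluted.
        return tokenize(("\u263a" + text).encode("utf-8"))[2:]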
@@ -988,38 +1011,33 @@ def _create_completion(
             +
             (
                 (
-                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    self.tokenize(prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0 or suffix is None))
                     if prompt != ""
-                    else (
-                        []
-                        if prefix_token_id >= 0 and suffix is not None
-                        else [self.token_bos()]
-                    )
+                    else []
                 )
                 if isinstance(prompt, str)
                 else prompt
             )
-            +
+        )
+        suffix_tokens: List[int] = (
             (
+                [suffix_token_id]
+                +
                 (
-                    [suffix_token_id]
-                    +
-                    (
-                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
-                        if suffix
-                        else []
-                    )
+                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[suffix_space_prefix:]
+                    if suffix
+                    else []
                 )
-                if suffix_token_id >= 0 and suffix is not None
-                else []
-            )
-            +
-            (
-                [middle_token_id]
-                if middle_token_id >= 0 and suffix is not None
-                else []
             )
+            if suffix_token_id >= 0 and suffix is not None
+            else []
+        )
+        middle_tokens: List[int] = (
+            [middle_token_id]
+            if middle_token_id >= 0 and suffix is not None
+            else []
+        )
+        prompt_tokens: List[int] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens) if self.spm_infill else (prefix_tokens + suffix_tokens + middle_tokens)) + eos_tokens
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
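Everything funnels into the single prompt_tokens expression above. As an illustrative sketch of the two infill layouts it switches between (special-token handling simplified, EOS/SEP omitted):

    def build_infill_prompt(bos, pre, prefix, suf, suffix, mid, spm):
        # PSM (default):          [BOS] <PRE> prefix <SUF> suffix <MID>
        # SPM (spm_infill=True):  [BOS] <SUF> suffix <PRE> prefix <MID>
        prefix_part = [pre] + prefix
        suffix_part = [suf] + suffix
        body = suffix_part + prefix_part if spm else prefix_part + suffix_part
        return [bos] + body + [mid]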
@@ -1176,7 +1194,7 @@ def logit_bias_processor(
                     # not sure how to handle this branch when dealing
                     # with CJK output, so keep it unchanged
                     for token in remaining_tokens:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                             continue
                         token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                         # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ def logit_bias_processor(

                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
-                    if token == self.token_bos():
+                    if token == bos_token_id:
                         continue
                     token_str = self.detokenize([token]).decode(
                         "utf-8", errors="ignore"
@@ -1431,7 +1449,7 @@ def logit_bias_processor(
             for idx, (token, token_str, logprobs_token) in enumerate(
                 zip(all_tokens, all_token_strs, all_logprobs)
             ):
-                if token == self.token_bos():
+                if token == bos_token_id:
                     continue
                 text_offsets.append(
                     text_offset
@@ -1858,6 +1876,7 @@ def __getstate__(self):
             type_k=self.context_params.type_k,
             type_v=self.context_params.type_v,
             # Misc
+            spm_infill=self.spm_infill,
             verbose=self.verbose,
         )
