@@ -410,8 +410,8 @@ def __init__(
410
410
if self .verbose :
411
411
print (f"Model metadata: { self .metadata } " , file = sys .stderr )
412
412
413
- eos_token_id = int ( self .metadata . get ( "tokenizer.ggml.eos_token_id" , self . token_eos ()) )
414
- bos_token_id = int ( self .metadata . get ( "tokenizer.ggml.bos_token_id" , self . token_bos ()) )
413
+ eos_token_id = self .token_eos ()
414
+ bos_token_id = self .token_bos ()
415
415
416
416
eos_token = self ._model .token_get_text (eos_token_id )
417
417
bos_token = self ._model .token_get_text (bos_token_id )
@@ -961,9 +961,9 @@ def _create_completion(
961
961
962
962
completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
963
963
created : int = int (time .time ())
964
- prefix_token_id : int = int ( self .metadata . get ( "tokenizer.ggml.prefix_token_id" , self . _model .token_prefix ()) )
965
- middle_token_id : int = int ( self .metadata . get ( "tokenizer.ggml.middle_token_id" , self . _model .token_middle ()) )
966
- suffix_token_id : int = int ( self .metadata . get ( "tokenizer.ggml.suffix_token_id" , self . _model .token_suffix ()) )
964
+ prefix_token_id : int = self ._model .token_prefix ()
965
+ middle_token_id : int = self ._model .token_middle ()
966
+ suffix_token_id : int = self ._model .token_suffix ()
967
967
# If prompt is empty, initialize completion with BOS token to avoid
968
968
# detokenization including a space at the beginning of the completion
969
969
completion_tokens : List [int ] = [] if len (prompt ) > 0 else [self .token_bos ()]
@@ -2084,3 +2084,19 @@ def __call__(
2084
2084
self , input_ids : npt .NDArray [np .intc ], logits : npt .NDArray [np .single ]
2085
2085
) -> bool :
2086
2086
return any ([stopping_criteria (input_ids , logits ) for stopping_criteria in self ])
2087
+
2088
+
2089
+ class MinTokensLogitsProcessor (LogitsProcessor ):
2090
+ def __init__ (self , min_tokens : int , token_eos : int ):
2091
+ self .min_tokens = min_tokens
2092
+ self .token_eos = token_eos
2093
+ self .prompt_tokens = None
2094
+
2095
+ def __call__ (
2096
+ self , input_ids : npt .NDArray [np .intc ], scores : npt .NDArray [np .single ]
2097
+ ) -> npt .NDArray [np .single ]:
2098
+ if self .prompt_tokens is None :
2099
+ self .prompt_tokens = len (input_ids )
2100
+ if len (input_ids ) - self .prompt_tokens < self .min_tokens :
2101
+ scores [self .token_eos ] = - np .inf
2102
+ return scores
0 commit comments