Commit e811a81

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main
2 parents: ca8e3c9 + 5212fb0

File tree

3 files changed (+49, -5 lines)


llama_cpp/llama.py

+21, -5
@@ -410,8 +410,8 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

-        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
-        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+        eos_token_id = self.token_eos()
+        bos_token_id = self.token_bos()

         eos_token = self._model.token_get_text(eos_token_id)
         bos_token = self._model.token_get_text(bos_token_id)
@@ -961,9 +961,9 @@ def _create_completion(

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
-        middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
-        suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
         completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
@@ -2084,3 +2084,19 @@ def __call__(
         self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
+class MinTokensLogitsProcessor(LogitsProcessor):
+    def __init__(self, min_tokens: int, token_eos: int):
+        self.min_tokens = min_tokens
+        self.token_eos = token_eos
+        self.prompt_tokens = None
+
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+    ) -> npt.NDArray[np.single]:
+        if self.prompt_tokens is None:
+            self.prompt_tokens = len(input_ids)
+        if len(input_ids) - self.prompt_tokens < self.min_tokens:
+            scores[self.token_eos] = -np.inf
+        return scores
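
The new MinTokensLogitsProcessor suppresses the EOS token until at least min_tokens completion tokens have been produced. A minimal usage sketch against the Python API (the model path and prompt here are placeholders, not part of this commit):

import llama_cpp

# Hypothetical GGUF model path.
llm = llama_cpp.Llama(model_path="./models/model.gguf")

# Mask EOS until at least 16 new tokens have been generated.
processors = llama_cpp.LogitsProcessorList(
    [llama_cpp.MinTokensLogitsProcessor(16, llm.token_eos())]
)

out = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    logits_processor=processors,
)
print(out["choices"][0]["text"])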

llama_cpp/server/app.py

+20
@@ -275,6 +275,7 @@ async def create_completion(
         "best_of",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@ async def create_completion(
     if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.CreateCompletionResponse,
         Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
         "n",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
     llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@ async def create_chat_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
     ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
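
With these server changes, min_tokens is accepted in the request body of both endpoints: it is excluded from the kwargs passed to the model and applied through a logits processor instead. A client-side sketch, assuming the server is running locally on port 8000 (the URL and prompt are illustrative only):

import requests

# Hypothetical server started with: python -m llama_cpp.server --model <path-to-gguf>
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Write a haiku about autumn.",
        "max_tokens": 64,
        "min_tokens": 10,  # field added by this commit; 0 disables the processor
    },
)
print(resp.json()["choices"][0]["text"])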

llama_cpp/server/types.py

+8
@@ -16,6 +16,12 @@
     default=16, ge=1, description="The maximum number of tokens to generate."
 )

+min_tokens_field = Field(
+    default=0,
+    ge=0,
+    description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
+)
+
 temperature_field = Field(
     default=0.8,
     description="Adjust the randomness of the generated text.\n\n"
@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
+    min_tokens: int = min_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    min_tokens: int = min_tokens_field
     logprobs: Optional[bool] = Field(
         default=False,
         description="Whether to output the logprobs or not. Default is True"
