Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7b46bb5

Browse files
committed
Re-order classes in llama.py
1 parent cc4630e commit 7b46bb5

File tree

1 file changed

+42
-40
lines changed

1 file changed

+42
-40
lines changed

llama_cpp/llama.py

+42-40
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import os
24
import sys
35
import uuid
@@ -40,46 +42,6 @@
4042
)
4143

4244

43-
class LlamaState:
44-
def __init__(
45-
self,
46-
input_ids: npt.NDArray[np.intc],
47-
scores: npt.NDArray[np.single],
48-
n_tokens: int,
49-
llama_state: bytes,
50-
llama_state_size: int,
51-
):
52-
self.input_ids = input_ids
53-
self.scores = scores
54-
self.n_tokens = n_tokens
55-
self.llama_state = llama_state
56-
self.llama_state_size = llama_state_size
57-
58-
59-
LogitsProcessor = Callable[
60-
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
61-
]
62-
63-
64-
class LogitsProcessorList(List[LogitsProcessor]):
65-
def __call__(
66-
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
67-
) -> npt.NDArray[np.single]:
68-
for processor in self:
69-
scores = processor(input_ids, scores)
70-
return scores
71-
72-
73-
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
74-
75-
76-
class StoppingCriteriaList(List[StoppingCriteria]):
77-
def __call__(
78-
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
79-
) -> bool:
80-
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
81-
82-
8345
class Llama:
8446
"""High-level Python wrapper for a llama.cpp model."""
8547

@@ -1733,3 +1695,43 @@ def decode(self, tokens: List[int]) -> str:
17331695
@classmethod
17341696
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
17351697
return cls(Llama(model_path=path, vocab_only=True))
1698+
1699+
1700+
class LlamaState:
1701+
def __init__(
1702+
self,
1703+
input_ids: npt.NDArray[np.intc],
1704+
scores: npt.NDArray[np.single],
1705+
n_tokens: int,
1706+
llama_state: bytes,
1707+
llama_state_size: int,
1708+
):
1709+
self.input_ids = input_ids
1710+
self.scores = scores
1711+
self.n_tokens = n_tokens
1712+
self.llama_state = llama_state
1713+
self.llama_state_size = llama_state_size
1714+
1715+
1716+
LogitsProcessor = Callable[
1717+
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
1718+
]
1719+
1720+
1721+
class LogitsProcessorList(List[LogitsProcessor]):
1722+
def __call__(
1723+
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
1724+
) -> npt.NDArray[np.single]:
1725+
for processor in self:
1726+
scores = processor(input_ids, scores)
1727+
return scores
1728+
1729+
1730+
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
1731+
1732+
1733+
class StoppingCriteriaList(List[StoppingCriteria]):
1734+
def __call__(
1735+
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
1736+
) -> bool:
1737+
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])

0 commit comments

Comments
 (0)