|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import os
|
2 | 4 | import sys
|
3 | 5 | import uuid
|
|
40 | 42 | )
|
41 | 43 |
|
42 | 44 |
|
43 |
| -class LlamaState: |
44 |
| - def __init__( |
45 |
| - self, |
46 |
| - input_ids: npt.NDArray[np.intc], |
47 |
| - scores: npt.NDArray[np.single], |
48 |
| - n_tokens: int, |
49 |
| - llama_state: bytes, |
50 |
| - llama_state_size: int, |
51 |
| - ): |
52 |
| - self.input_ids = input_ids |
53 |
| - self.scores = scores |
54 |
| - self.n_tokens = n_tokens |
55 |
| - self.llama_state = llama_state |
56 |
| - self.llama_state_size = llama_state_size |
57 |
| - |
58 |
| - |
59 |
| -LogitsProcessor = Callable[ |
60 |
| - [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single] |
61 |
| -] |
62 |
| - |
63 |
| - |
64 |
| -class LogitsProcessorList(List[LogitsProcessor]): |
65 |
| - def __call__( |
66 |
| - self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single] |
67 |
| - ) -> npt.NDArray[np.single]: |
68 |
| - for processor in self: |
69 |
| - scores = processor(input_ids, scores) |
70 |
| - return scores |
71 |
| - |
72 |
| - |
73 |
| -StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] |
74 |
| - |
75 |
| - |
76 |
| -class StoppingCriteriaList(List[StoppingCriteria]): |
77 |
| - def __call__( |
78 |
| - self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single] |
79 |
| - ) -> bool: |
80 |
| - return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) |
81 |
| - |
82 |
| - |
83 | 45 | class Llama:
|
84 | 46 | """High-level Python wrapper for a llama.cpp model."""
|
85 | 47 |
|
@@ -1733,3 +1695,43 @@ def decode(self, tokens: List[int]) -> str:
|
@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
    """Alternate constructor: wrap a ``Llama`` opened from *path*.

    The ``Llama`` instance is created with ``vocab_only=True`` and then
    handed to the class constructor.
    """
    vocab_only_model = Llama(model_path=path, vocab_only=True)
    return cls(vocab_only_model)
|
| 1698 | + |
| 1699 | + |
class LlamaState:
    """Plain container for the values passed to its constructor.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name; no copying or validation is performed.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Token-side fields: ids, their scores, and the token count.
        self.input_ids, self.scores, self.n_tokens = input_ids, scores, n_tokens
        # State bytes together with the size value supplied by the caller.
        self.llama_state, self.llama_state_size = llama_state, llama_state_size
| 1714 | + |
| 1715 | + |
# A logits processor is called with (input_ids, scores) and returns scores.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]],
    npt.NDArray[np.single],
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors that can itself be called as one."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Apply every processor in list order and return the final scores.

        Each processor receives the same ``input_ids`` and the scores
        returned by the processor before it. An empty list returns
        ``scores`` unchanged.
        """
        current = scores
        for apply in self:
            current = apply(input_ids, current)
        return current
| 1728 | + |
| 1729 | + |
# A stopping criterion is called with (input_ids, logits) and returns a bool.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria that can itself be called as one."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        """Return True if any criterion in the list reports a stop.

        Criteria are evaluated in list order with the same ``input_ids`` and
        ``logits``. Evaluation stops at the first truthy result, so later
        criteria are not called once a stop has been signalled. An empty
        list returns False.
        """
        # Generator (not a list) so any() can short-circuit and no throwaway
        # list of results is built.
        return any(criterion(input_ids, logits) for criterion in self)
0 commit comments