|
3 | 3 | import uuid
|
4 | 4 | import time
|
5 | 5 | import multiprocessing
|
6 |
| -from abc import ABC, abstractmethod |
7 | 6 | from typing import (
|
8 | 7 | List,
|
9 | 8 | Optional,
|
|
12 | 11 | Sequence,
|
13 | 12 | Iterator,
|
14 | 13 | Deque,
|
15 |
| - Tuple, |
16 | 14 | Callable,
|
17 | 15 | )
|
18 |
| -from collections import deque, OrderedDict |
| 16 | +from collections import deque |
19 | 17 |
|
20 |
| -import diskcache |
21 | 18 | import ctypes
|
22 | 19 |
|
23 | 20 | from .llama_types import *
|
24 | 21 | from .llama_grammar import LlamaGrammar
|
| 22 | +from .llama_cache import ( |
| 23 | + BaseLlamaCache, |
| 24 | + LlamaCache, # type: ignore |
| 25 | + LlamaDiskCache, # type: ignore |
| 26 | + LlamaRAMCache, # type: ignore |
| 27 | +) |
25 | 28 | import llama_cpp.llama_cpp as llama_cpp
|
26 | 29 | import llama_cpp.llama_chat_format as llama_chat_format
|
27 | 30 |
|
|
31 | 34 | from ._utils import suppress_stdout_stderr
|
32 | 35 |
|
33 | 36 |
|
34 |
| -class BaseLlamaCache(ABC): |
35 |
| - """Base cache class for a llama.cpp model.""" |
36 |
| - |
37 |
| - def __init__(self, capacity_bytes: int = (2 << 30)): |
38 |
| - self.capacity_bytes = capacity_bytes |
39 |
| - |
40 |
| - @property |
41 |
| - @abstractmethod |
42 |
| - def cache_size(self) -> int: |
43 |
| - raise NotImplementedError |
44 |
| - |
45 |
| - def _find_longest_prefix_key( |
46 |
| - self, |
47 |
| - key: Tuple[int, ...], |
48 |
| - ) -> Optional[Tuple[int, ...]]: |
49 |
| - pass |
50 |
| - |
51 |
| - @abstractmethod |
52 |
| - def __getitem__(self, key: Sequence[int]) -> "LlamaState": |
53 |
| - raise NotImplementedError |
54 |
| - |
55 |
| - @abstractmethod |
56 |
| - def __contains__(self, key: Sequence[int]) -> bool: |
57 |
| - raise NotImplementedError |
58 |
| - |
59 |
| - @abstractmethod |
60 |
| - def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None: |
61 |
| - raise NotImplementedError |
62 |
| - |
63 |
| - |
64 |
| -class LlamaRAMCache(BaseLlamaCache): |
65 |
| - """Cache for a llama.cpp model using RAM.""" |
66 |
| - |
67 |
| - def __init__(self, capacity_bytes: int = (2 << 30)): |
68 |
| - super().__init__(capacity_bytes) |
69 |
| - self.capacity_bytes = capacity_bytes |
70 |
| - self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict() |
71 |
| - |
72 |
| - @property |
73 |
| - def cache_size(self): |
74 |
| - return sum([state.llama_state_size for state in self.cache_state.values()]) |
75 |
| - |
76 |
| - def _find_longest_prefix_key( |
77 |
| - self, |
78 |
| - key: Tuple[int, ...], |
79 |
| - ) -> Optional[Tuple[int, ...]]: |
80 |
| - min_len = 0 |
81 |
| - min_key = None |
82 |
| - keys = ( |
83 |
| - (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() |
84 |
| - ) |
85 |
| - for k, prefix_len in keys: |
86 |
| - if prefix_len > min_len: |
87 |
| - min_len = prefix_len |
88 |
| - min_key = k |
89 |
| - return min_key |
90 |
| - |
91 |
| - def __getitem__(self, key: Sequence[int]) -> "LlamaState": |
92 |
| - key = tuple(key) |
93 |
| - _key = self._find_longest_prefix_key(key) |
94 |
| - if _key is None: |
95 |
| - raise KeyError("Key not found") |
96 |
| - value = self.cache_state[_key] |
97 |
| - self.cache_state.move_to_end(_key) |
98 |
| - return value |
99 |
| - |
100 |
| - def __contains__(self, key: Sequence[int]) -> bool: |
101 |
| - return self._find_longest_prefix_key(tuple(key)) is not None |
102 |
| - |
103 |
| - def __setitem__(self, key: Sequence[int], value: "LlamaState"): |
104 |
| - key = tuple(key) |
105 |
| - if key in self.cache_state: |
106 |
| - del self.cache_state[key] |
107 |
| - self.cache_state[key] = value |
108 |
| - while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: |
109 |
| - self.cache_state.popitem(last=False) |
110 |
| - |
111 |
| - |
112 |
| -# Alias for backwards compatibility |
113 |
| -LlamaCache = LlamaRAMCache |
114 |
| - |
115 |
| - |
116 |
| -class LlamaDiskCache(BaseLlamaCache): |
117 |
| - """Cache for a llama.cpp model using disk.""" |
118 |
| - |
119 |
| - def __init__( |
120 |
| - self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) |
121 |
| - ): |
122 |
| - super().__init__(capacity_bytes) |
123 |
| - self.cache = diskcache.Cache(cache_dir) |
124 |
| - |
125 |
| - @property |
126 |
| - def cache_size(self): |
127 |
| - return int(self.cache.volume()) # type: ignore |
128 |
| - |
129 |
| - def _find_longest_prefix_key( |
130 |
| - self, |
131 |
| - key: Tuple[int, ...], |
132 |
| - ) -> Optional[Tuple[int, ...]]: |
133 |
| - min_len = 0 |
134 |
| - min_key: Optional[Tuple[int, ...]] = None |
135 |
| - for k in self.cache.iterkeys(): # type: ignore |
136 |
| - prefix_len = Llama.longest_token_prefix(k, key) |
137 |
| - if prefix_len > min_len: |
138 |
| - min_len = prefix_len |
139 |
| - min_key = k # type: ignore |
140 |
| - return min_key |
141 |
| - |
142 |
| - def __getitem__(self, key: Sequence[int]) -> "LlamaState": |
143 |
| - key = tuple(key) |
144 |
| - _key = self._find_longest_prefix_key(key) |
145 |
| - if _key is None: |
146 |
| - raise KeyError("Key not found") |
147 |
| - value: "LlamaState" = self.cache.pop(_key) # type: ignore |
148 |
| - # NOTE: This puts an integer as key in cache, which breaks, |
149 |
| - # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens |
150 |
| - # self.cache.push(_key, side="front") # type: ignore |
151 |
| - return value |
152 |
| - |
153 |
| - def __contains__(self, key: Sequence[int]) -> bool: |
154 |
| - return self._find_longest_prefix_key(tuple(key)) is not None |
155 |
| - |
156 |
| - def __setitem__(self, key: Sequence[int], value: "LlamaState"): |
157 |
| - print("LlamaDiskCache.__setitem__: called", file=sys.stderr) |
158 |
| - key = tuple(key) |
159 |
| - if key in self.cache: |
160 |
| - print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) |
161 |
| - del self.cache[key] |
162 |
| - self.cache[key] = value |
163 |
| - print("LlamaDiskCache.__setitem__: set", file=sys.stderr) |
164 |
| - while self.cache_size > self.capacity_bytes and len(self.cache) > 0: |
165 |
| - key_to_remove = next(iter(self.cache)) |
166 |
| - del self.cache[key_to_remove] |
167 |
| - print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) |
168 |
| - |
169 |
| - |
170 | 37 | class LlamaState:
|
171 | 38 | def __init__(
|
172 | 39 | self,
|
|
0 commit comments