
Commit 3b92419

Move cache classes to llama_cache submodule.
1 parent 6981597 commit 3b92419

File tree: 2 files changed, +157, -140 lines

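The move is intended to be import-compatible for downstream code: llama_cpp/llama.py re-imports the cache classes from the new llama_cpp.llama_cache submodule (see the diff below). A minimal sketch of the two import paths follows; the names come straight from this diff, while re-exports at the top-level llama_cpp package are an assumption about the package's __init__ and are not shown here.

# New canonical location introduced by this commit:
from llama_cpp.llama_cache import (
    BaseLlamaCache,
    LlamaCache,      # alias for LlamaRAMCache, kept for backwards compatibility
    LlamaDiskCache,
    LlamaRAMCache,
)

# Old path: still resolves because llama.py re-imports these names from .llama_cache.
from llama_cpp.llama import LlamaCache, LlamaDiskCache, LlamaRAMCache

assert LlamaCache is LlamaRAMCache  # the alias defined in llama_cache.py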

llama_cpp/llama.py (+7, -140)
@@ -3,7 +3,6 @@
 import uuid
 import time
 import multiprocessing
-from abc import ABC, abstractmethod
 from typing import (
     List,
     Optional,
@@ -12,16 +11,20 @@
     Sequence,
     Iterator,
     Deque,
-    Tuple,
     Callable,
 )
-from collections import deque, OrderedDict
+from collections import deque

-import diskcache
 import ctypes

 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from .llama_cache import (
+    BaseLlamaCache,
+    LlamaCache,  # type: ignore
+    LlamaDiskCache,  # type: ignore
+    LlamaRAMCache,  # type: ignore
+)
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format

@@ -31,142 +34,6 @@
 from ._utils import suppress_stdout_stderr


-class BaseLlamaCache(ABC):
-    """Base cache class for a llama.cpp model."""
-
-    def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.capacity_bytes = capacity_bytes
-
-    @property
-    @abstractmethod
-    def cache_size(self) -> int:
-        raise NotImplementedError
-
-    def _find_longest_prefix_key(
-        self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
-        pass
-
-    @abstractmethod
-    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
-        raise NotImplementedError
-
-    @abstractmethod
-    def __contains__(self, key: Sequence[int]) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None:
-        raise NotImplementedError
-
-
-class LlamaRAMCache(BaseLlamaCache):
-    """Cache for a llama.cpp model using RAM."""
-
-    def __init__(self, capacity_bytes: int = (2 << 30)):
-        super().__init__(capacity_bytes)
-        self.capacity_bytes = capacity_bytes
-        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
-
-    @property
-    def cache_size(self):
-        return sum([state.llama_state_size for state in self.cache_state.values()])
-
-    def _find_longest_prefix_key(
-        self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
-        min_len = 0
-        min_key = None
-        keys = (
-            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
-        )
-        for k, prefix_len in keys:
-            if prefix_len > min_len:
-                min_len = prefix_len
-                min_key = k
-        return min_key
-
-    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
-        key = tuple(key)
-        _key = self._find_longest_prefix_key(key)
-        if _key is None:
-            raise KeyError("Key not found")
-        value = self.cache_state[_key]
-        self.cache_state.move_to_end(_key)
-        return value
-
-    def __contains__(self, key: Sequence[int]) -> bool:
-        return self._find_longest_prefix_key(tuple(key)) is not None
-
-    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
-        key = tuple(key)
-        if key in self.cache_state:
-            del self.cache_state[key]
-        self.cache_state[key] = value
-        while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
-            self.cache_state.popitem(last=False)
-
-
-# Alias for backwards compatibility
-LlamaCache = LlamaRAMCache
-
-
-class LlamaDiskCache(BaseLlamaCache):
-    """Cache for a llama.cpp model using disk."""
-
-    def __init__(
-        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
-    ):
-        super().__init__(capacity_bytes)
-        self.cache = diskcache.Cache(cache_dir)
-
-    @property
-    def cache_size(self):
-        return int(self.cache.volume())  # type: ignore
-
-    def _find_longest_prefix_key(
-        self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
-        min_len = 0
-        min_key: Optional[Tuple[int, ...]] = None
-        for k in self.cache.iterkeys():  # type: ignore
-            prefix_len = Llama.longest_token_prefix(k, key)
-            if prefix_len > min_len:
-                min_len = prefix_len
-                min_key = k  # type: ignore
-        return min_key
-
-    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
-        key = tuple(key)
-        _key = self._find_longest_prefix_key(key)
-        if _key is None:
-            raise KeyError("Key not found")
-        value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        # NOTE: This puts an integer as key in cache, which breaks,
-        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
-        # self.cache.push(_key, side="front")  # type: ignore
-        return value
-
-    def __contains__(self, key: Sequence[int]) -> bool:
-        return self._find_longest_prefix_key(tuple(key)) is not None
-
-    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
-        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
-        key = tuple(key)
-        if key in self.cache:
-            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
-            del self.cache[key]
-        self.cache[key] = value
-        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
-        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
-            key_to_remove = next(iter(self.cache))
-            del self.cache[key_to_remove]
-            print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
-
-
 class LlamaState:
     def __init__(
         self,

llama_cpp/llama_cache.py (new file, +150)
@@ -0,0 +1,150 @@
+import sys
+from abc import ABC, abstractmethod
+from typing import (
+    Optional,
+    Sequence,
+    Tuple,
+)
+from collections import OrderedDict
+
+import diskcache
+
+import llama_cpp.llama
+
+from .llama_types import *
+
+
+class BaseLlamaCache(ABC):
+    """Base cache class for a llama.cpp model."""
+
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.capacity_bytes = capacity_bytes
+
+    @property
+    @abstractmethod
+    def cache_size(self) -> int:
+        raise NotImplementedError
+
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
+        pass
+
+    @abstractmethod
+    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+        raise NotImplementedError
+
+    @abstractmethod
+    def __contains__(self, key: Sequence[int]) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
+        raise NotImplementedError
+
+
+class LlamaRAMCache(BaseLlamaCache):
+    """Cache for a llama.cpp model using RAM."""
+
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        super().__init__(capacity_bytes)
+        self.capacity_bytes = capacity_bytes
+        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict()
+
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])
+
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key
+
+    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
+        if _key is None:
+            raise KeyError("Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value
+
+    def __contains__(self, key: Sequence[int]) -> bool:
+        return self._find_longest_prefix_key(tuple(key)) is not None
+
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
+            self.cache_state.popitem(last=False)
+
+
+# Alias for backwards compatibility
+LlamaCache = LlamaRAMCache
+
+
+class LlamaDiskCache(BaseLlamaCache):
+    """Cache for a llama.cpp model using disk."""
+
+    def __init__(
+        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
+    ):
+        super().__init__(capacity_bytes)
+        self.cache = diskcache.Cache(cache_dir)
+
+    @property
+    def cache_size(self):
+        return int(self.cache.volume())  # type: ignore
+
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
+        min_len = 0
+        min_key: Optional[Tuple[int, ...]] = None
+        for k in self.cache.iterkeys():  # type: ignore
+            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k  # type: ignore
+        return min_key
+
+    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
+        if _key is None:
+            raise KeyError("Key not found")
+        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key)  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
+        return value
+
+    def __contains__(self, key: Sequence[int]) -> bool:
+        return self._find_longest_prefix_key(tuple(key)) is not None
+
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
+        key = tuple(key)
+        if key in self.cache:
+            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
+            del self.cache[key]
+        self.cache[key] = value
+        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
+        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
+            key_to_remove = next(iter(self.cache))
+            del self.cache[key_to_remove]
+            print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
