Commit 850416a

Merge branch 'main' into batch-processing
2 parents e7ef07d + 6f08021 commit 850416a

11 files changed (+321 −27 lines)

README.md (+1)

@@ -14,6 +14,7 @@ This package provides:
 - High-level Python API for text completion
 - OpenAI-like API
 - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
+- [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
 - OpenAI compatible web server
 - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
 - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)

docs/templates.md (new file, +52)

@@ -0,0 +1,52 @@
+# Templates
+
+This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.
+
+## Introduction
+
+- Brief explanation of the `llama-cpp-python` project's need for a templating system.
+- Overview of the `llama-2` model's interaction with templating.
+
+## Jinja2 Dependency Integration
+
+- Rationale for choosing Jinja2 as the templating engine.
+- Compatibility with Hugging Face's `transformers`.
+- Desire for advanced templating features and simplicity.
+- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
+
+## Template Management Refactor
+
+- Summary of the refactor and the motivation behind it.
+- Description of the new chat handler selection logic:
+  1. Preference for a user-specified `chat_handler`.
+  2. Fallback to a user-specified `chat_format`.
+  3. Defaulting to a chat format from a `.gguf` file if available.
+  4. Utilizing the `llama2` default chat format as the final fallback.
+- Ensuring backward compatibility throughout the refactor.
+
+## Implementation Details
+
+- In-depth look at the new `AutoChatFormatter` class.
+- Example code snippets showing how to utilize the Jinja2 environment and templates.
+- Guidance on how to provide custom templates or use defaults.
+
+## Testing and Validation
+
+- Outline of the testing strategy to ensure seamless integration.
+- Steps for validating backward compatibility with existing implementations.
+
+## Benefits and Impact
+
+- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
+- Discussion of the potential impact on current users and contributors.
+
+## Future Work
+
+- Exploration of how templating can evolve within the project.
+- Consideration of additional features or optimizations for the templating engine.
+- Mechanisms for community feedback on the templating system.
+
+## Conclusion
+
+- Final thoughts on the integration of Jinja2 templating.
+- Call to action for community involvement and feedback.
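
The four-step fallback listed under "Template Management Refactor" above can be pictured with a short, self-contained sketch. This is only an illustration: the handler table and the resolver function below are placeholders, not APIs introduced by this commit.

# Illustrative sketch of the chat handler selection order described above.
# The _HANDLERS registry and resolve_chat_handler() are hypothetical stand-ins.
from typing import Callable, Dict, Optional

_HANDLERS: Dict[str, Callable] = {"llama-2": lambda messages: messages}  # placeholder registry

def resolve_chat_handler(
    chat_handler: Optional[Callable] = None,
    chat_format: Optional[str] = None,
    gguf_chat_format: Optional[str] = None,
) -> Callable:
    if chat_handler is not None:        # 1. prefer a user-specified chat_handler
        return chat_handler
    if chat_format is not None:         # 2. fall back to a user-specified chat_format
        return _HANDLERS[chat_format]
    if gguf_chat_format is not None:    # 3. then a chat format read from the .gguf file
        return _HANDLERS.get(gguf_chat_format, _HANDLERS["llama-2"])
    return _HANDLERS["llama-2"]         # 4. llama2 default as the final fallback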

llama_cpp/_internals.py (−2)

@@ -26,7 +26,6 @@

 class _LlamaModel:
     """Intermediate Python wrapper for a llama.cpp llama_model.
-
     NOTE: For stability it's recommended you use the Llama class instead."""

     _llama_free_model = None

@@ -213,7 +212,6 @@ def default_params():

 class _LlamaContext:
     """Intermediate Python wrapper for a llama.cpp llama_context.
-
     NOTE: For stability it's recommended you use the Llama class instead."""

     _llama_free = None

llama_cpp/llama.py (+12 −8)

@@ -5,7 +5,6 @@
 import uuid
 import time
 import multiprocessing
-
 from typing import (
     List,
     Optional,

@@ -25,18 +24,23 @@

 from .llama_types import *
 from .llama_grammar import LlamaGrammar
-from .llama_cache import BaseLlamaCache
-
+from .llama_cache import (
+    BaseLlamaCache,
+    LlamaCache, # type: ignore
+    LlamaDiskCache, # type: ignore
+    LlamaRAMCache, # type: ignore
+)
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format

 from ._internals import (
-    _LlamaModel,
-    _LlamaContext,
-    _LlamaBatch,
+    _LlamaModel, # type: ignore
+    _LlamaContext, # type: ignore
+    _LlamaBatch, # type: ignore
     _LlamaTokenDataArray, # type: ignore
-    _LlamaSamplingParams,
-    _LlamaSamplingContext,
+    _LlamaSamplingParams, # type: ignore
+    _LlamaSamplingContext, # type: ignore
+
 )
 from ._utils import suppress_stdout_stderr
llama_cpp/llama_cache.py (+8 −12)

@@ -1,5 +1,4 @@
 import sys
-
 from abc import ABC, abstractmethod
 from typing import (
     Optional,

@@ -8,8 +7,12 @@
 )
 from collections import OrderedDict

+import diskcache
+
 import llama_cpp.llama

+from .llama_types import *
+

 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""

@@ -37,9 +40,7 @@ def __contains__(self, key: Sequence[int]) -> bool:
         raise NotImplementedError

     @abstractmethod
-    def __setitem__(
-        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
-    ) -> None:
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
         raise NotImplementedError


@@ -49,9 +50,7 @@ class LlamaRAMCache(BaseLlamaCache):
     def __init__(self, capacity_bytes: int = (2 << 30)):
         super().__init__(capacity_bytes)
         self.capacity_bytes = capacity_bytes
-        self.cache_state: OrderedDict[
-            Tuple[int, ...], "llama_cpp.llama.LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict()

     @property
     def cache_size(self):

@@ -64,8 +63,7 @@ def _find_longest_prefix_key(
         min_len = 0
         min_key = None
         keys = (
-            (k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
-            for k in self.cache_state.keys()
+            (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
         )
         for k, prefix_len in keys:
             if prefix_len > min_len:

@@ -104,8 +102,6 @@ class LlamaDiskCache(BaseLlamaCache):
     def __init__(
         self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
     ):
-        import diskcache
-
         super().__init__(capacity_bytes)
         self.cache = diskcache.Cache(cache_dir)

@@ -131,7 +127,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
         _key = self._find_longest_prefix_key(key)
         if _key is None:
             raise KeyError("Key not found")
-        value: "LlamaState" = self.cache.pop(_key) # type: ignore
+        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore
         # NOTE: This puts an integer as key in cache, which breaks,
         # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
         # self.cache.push(_key, side="front") # type: ignore
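
For orientation, a minimal usage sketch of the disk cache touched above. It assumes the `Llama.set_cache` helper available in llama-cpp-python; the model path is a placeholder.

# Hypothetical usage sketch: attach a disk-backed cache to a model so repeated
# prompts with a shared prefix can reuse saved llama state.
# The model path is a placeholder; Llama.set_cache is an assumed helper.
from llama_cpp import Llama
from llama_cpp.llama_cache import LlamaDiskCache

llm = Llama(model_path="./models/model.gguf")
llm.set_cache(LlamaDiskCache(cache_dir=".cache/llama_cache", capacity_bytes=2 << 30))

out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])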

llama_cpp/llama_jinja_format.py (new file, +138)

@@ -0,0 +1,138 @@
+"""
+llama_cpp/llama_jinja_format.py
+"""
+import dataclasses
+from typing import Any, Callable, Dict, List, Optional, Protocol, Union
+
+import jinja2
+from jinja2 import Template
+
+# NOTE: We sacrifice readability for usability.
+# It will fail to work as expected if we attempt to format it in a readable way.
+llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
+
+
+class MetaSingleton(type):
+    """
+    Metaclass for implementing the Singleton pattern.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class Singleton(object, metaclass=MetaSingleton):
+    """
+    Base class for implementing the Singleton pattern.
+    """
+
+    def __init__(self):
+        super(Singleton, self).__init__()
+
+
+@dataclasses.dataclass
+class ChatFormatterResponse:
+    prompt: str
+    stop: Optional[Union[str, List[str]]] = None
+
+
+# Base Chat Formatter Protocol
+class ChatFormatterInterface(Protocol):
+    def __init__(self, template: Optional[object] = None):
+        ...
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs,
+    ) -> ChatFormatterResponse:
+        ...
+
+    @property
+    def template(self) -> str:
+        ...
+
+
+class AutoChatFormatter(ChatFormatterInterface):
+    def __init__(
+        self,
+        template: Optional[str] = None,
+        template_class: Optional[Template] = None,
+    ):
+        if template is not None:
+            self._template = template
+        else:
+            self._template = llama2_template # default template
+
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(
+            self._template,
+            template_class=template_class,
+        )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        formatted_sequence = self._environment.render(messages=messages, **kwargs)
+        return ChatFormatterResponse(prompt=formatted_sequence)
+
+    @property
+    def template(self) -> str:
+        return self._template
+
+
+class FormatterNotFoundException(Exception):
+    pass
+
+
+class ChatFormatterFactory(Singleton):
+    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
+
+    def register_formatter(
+        self,
+        name: str,
+        formatter_callable: Callable[[], ChatFormatterInterface],
+        overwrite=False,
+    ):
+        if not overwrite and name in self._chat_formatters:
+            raise ValueError(
+                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
+            )
+        self._chat_formatters[name] = formatter_callable
+
+    def unregister_formatter(self, name: str):
+        if name in self._chat_formatters:
+            del self._chat_formatters[name]
+        else:
+            raise ValueError(f"No formatter registered under the name '{name}'.")
+
+    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
+        try:
+            formatter_callable = self._chat_formatters[name]
+            return formatter_callable()
+        except KeyError:
+            raise FormatterNotFoundException(
+                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
+            )
+
+
+# Define a chat format class
+class Llama2Formatter(AutoChatFormatter):
+    def __init__(self):
+        super().__init__(llama2_template)
+
+
+# With the Singleton pattern applied, regardless of where or how many times
+# ChatFormatterFactory() is called, it will always return the same instance
+# of the factory, ensuring that the factory's state is consistent throughout
+# the application.
+ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
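
A brief usage sketch of the factory and formatter registered in the new module above; the message contents are illustrative.

# Usage sketch for llama_cpp/llama_jinja_format.py as added in this commit.
from llama_cpp.llama_jinja_format import ChatFormatterFactory

factory = ChatFormatterFactory()                      # Singleton: same instance everywhere
formatter = factory.get_formatter_by_name("llama-2")  # instantiates Llama2Formatter

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi there!"},
]
result = formatter(messages)  # renders the llama-2 Jinja2 template
print(result.prompt)          # "<<SYS>> You are a helpful assistant. <</SYS>>\n[INST] Hi there! [/INST]\n"
print(result.stop)            # None -- this formatter does not set stop sequences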

llama_cpp/server/app.py (+57 −2)

@@ -197,7 +197,36 @@ async def authenticate(


 @router.post(
-    "/v1/completions", summary="Completion", dependencies=[Depends(authenticate)]
+    "/v1/completions",
+    summary="Completion",
+    dependencies=[Depends(authenticate)],
+    response_model= Union[
+        llama_cpp.CreateCompletionResponse,
+        str,
+    ],
+    responses={
+        "200": {
+            "description": "Successful Response",
+            "content": {
+                "application/json": {
+                    "schema": {
+                        "anyOf": [
+                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
+                        ],
+                        "title": "Completion response, when stream=False",
+                    }
+                },
+                "text/event-stream":{
+                    "schema": {
+                        "type": "string",
+                        "title": "Server Side Streaming response, when stream=True. " +
+                            "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
+                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                    }
+                }
+            },
+        }
+    },
 )
 @router.post(
     "/v1/engines/copilot-codex/completions",

@@ -280,7 +309,33 @@ async def create_embedding(


 @router.post(
-    "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)]
+    "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
+    response_model= Union[
+        llama_cpp.ChatCompletion, str
+    ],
+    responses={
+        "200": {
+            "description": "Successful Response",
+            "content": {
+                "application/json": {
+                    "schema": {
+                        "anyOf": [
+                            {"$ref": "#/components/schemas/CreateChatCompletionResponse"}
+                        ],
+                        "title": "Completion response, when stream=False",
+                    }
+                },
+                "text/event-stream":{
+                    "schema": {
+                        "type": "string",
+                        "title": "Server Side Streaming response, when stream=True" +
+                            "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
+                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                    }
+                }
+            },
+        }
+    },
 )
 async def create_chat_completion(
     request: Request,
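
Finally, a hedged client-side sketch of consuming the streaming variant documented above. It assumes a server listening on localhost:8000 and OpenAI-style request fields; `requests` stands in for any HTTP client.

# Client-side sketch: stream /v1/completions and parse the SSE "data:" lines
# described in the responses schema above. Host, port, and prompt are placeholders.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Hello, my name is", "max_tokens": 32, "stream": True},
    stream=True,
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue  # SSE events are separated by blank lines
        payload = raw.decode("utf-8").removeprefix("data: ")
        if payload.strip() == "[DONE]":
            break  # end-of-stream sentinel, as shown in the example above
        chunk = json.loads(payload)  # one CreateCompletionResponse chunk
        print(chunk["choices"][0]["text"], end="", flush=True)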
