Commit ae47e4f

Add chat format

1 parent 9c68382 commit ae47e4f

File tree

2 files changed (+315, -91 lines)


llama_cpp/llama.py (+23, -91)
@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
 
@@ -387,6 +391,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1578,7 +1584,7 @@ def _convert_completion_to_chat(
 
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1613,11 +1619,19 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        completion_or_chunks = self.chat_completion_template.create_chat_completion(
-            self,
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
             messages=messages,
-            functions=functions,
-            function_call=function_call,
+        )
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
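
The prompt-building logic now lives in llama_chat_format.py, the second file in this commit (+292 lines, not shown on this page). From the call site above we can only infer the interface: get_chat_format(name) returns a callable that takes the message list and produces an object exposing a prompt string and an optional stop field. A minimal sketch of such a module follows; every name apart from get_chat_format and the "llama-2" key is an assumption, not the file's actual contents.

# Hypothetical sketch of the llama_chat_format module's shape; only
# get_chat_format and the "llama-2" format name are confirmed by the diff.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union


@dataclass
class ChatFormatterResponse:
    # The two fields create_chat_completion reads: result.prompt, result.stop.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


_CHAT_FORMATS: Dict[str, Callable[..., ChatFormatterResponse]] = {}


def register_chat_format(name: str):
    """Register a formatter under a chat_format name."""
    def decorator(f: Callable[..., ChatFormatterResponse]):
        _CHAT_FORMATS[name] = f
        return f
    return decorator


def get_chat_format(name: str) -> Callable[..., ChatFormatterResponse]:
    try:
        return _CHAT_FORMATS[name]
    except KeyError:
        raise ValueError(f"Invalid chat format: {name}")


@register_chat_format("llama-2")
def format_llama2(
    messages: List[Dict[str, str]], **kwargs: Any
) -> ChatFormatterResponse:
    # Simplified llama-2 rendering; the real template nests the system
    # prompt inside the first [INST] block.
    prompt = ""
    for message in messages:
        content = message.get("content") or ""
        if message["role"] == "system":
            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
        elif message["role"] == "user":
            prompt += f"[INST] {content} [/INST]"
        else:
            prompt += f" {content} "
    return ChatFormatterResponse(prompt=prompt, stop="[INST]")
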
@@ -1675,6 +1689,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1724,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
@@ -1821,89 +1839,3 @@ def decode(self, tokens: List[int]) -> str:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
-
-
-class ChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    @abstractmethod
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        raise NotImplementedError
-
-
-class DefaultChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
-        )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        return llama.create_completion(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            stream=stream,
-            max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        )
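
Putting the pieces together, a short usage sketch of the new parameter (the model path is a placeholder; any model that follows the llama-2 prompt format applies, and "llama-2" is also the default chat_format):

from llama_cpp import Llama

# Placeholder path; point this at a real local model file.
llm = Llama(model_path="./models/llama-2-7b-chat.bin", chat_format="llama-2")

# Messages are the OpenAI-style role/content dicts accepted by
# create_chat_completion (ChatCompletionRequestMessage).
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
    max_tokens=128,
)
print(response["choices"][0]["message"]["content"])

Because the formatter can contribute its own stop sequences, user-supplied stop strings are merged with the format's rather than replaced, as the create_chat_completion hunk above shows.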
