@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
 
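The `chat_format` argument added above is only consumed when `create_chat_completion` is called. A minimal usage sketch of the constructor and chat call as they look after this change (the model path is a placeholder, and the messages follow the OpenAI-style role/content dicts that `ChatCompletionRequestMessage` describes):

```python
from llama_cpp import Llama

# Placeholder model path; any chat-tuned GGUF model works here.
llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")

# Messages are OpenAI-style role/content dicts.
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets of the solar system."},
    ],
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])
```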
@@ -387,6 +391,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1578,7 +1584,7 @@ def _convert_completion_to_chat(
 
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1613,11 +1619,19 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        completion_or_chunks = self.chat_completion_template.create_chat_completion(
-            self,
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
             messages=messages,
-            functions=functions,
-            function_call=function_call,
+        )
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
            temperature=temperature,
             top_p=top_p,
             top_k=top_k,
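The `llama_chat_format` module itself is not part of this diff, so the sketch below only illustrates the contract the new call site relies on: `get_chat_format(name)` returns a callable that accepts `messages=` and hands back an object with a `prompt` string and an optional `stop` value. The `ChatFormatterResponse` dataclass and `format_llama2` function below are assumed names for illustration, not necessarily what the module defines:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ChatFormatterResponse:
    # What the call site reads back: a rendered prompt plus optional stop string(s).
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


def format_llama2(messages: List[dict], **kwargs) -> ChatFormatterResponse:
    # Illustrative llama-2 style rendering of role/content messages.
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n"
        elif message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]"
        else:  # assistant
            prompt += f" {message['content']} "
    return ChatFormatterResponse(prompt=prompt, stop="[INST]")


# get_chat_format("llama-2") would return a callable with this shape, so
# create_chat_completion can do:
#   result = format(messages=messages)
#   prompt, stop = result.prompt, result.stop
```

Keeping formatters as plain callables that return a prompt/stop pair is what lets `create_chat_completion` shrink to a thin wrapper around `create_completion`.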
@@ -1675,6 +1689,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1724,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
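The two pickling hunks keep `chat_format` in the `__getstate__`/`__setstate__` round trip, so a restored instance formats prompts the same way as the original. A small sketch (note that unpickling calls `__init__` again and therefore reloads the model, so this is illustrative rather than cheap):

```python
import pickle

from llama_cpp import Llama

# Placeholder model path.
llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")

# __getstate__ now serializes chat_format and __setstate__ feeds it back
# into __init__, so the setting survives the round trip.
restored = pickle.loads(pickle.dumps(llm))
assert restored.chat_format == "llama-2"
```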
@@ -1821,89 +1839,3 @@ def decode(self, tokens: List[int]) -> str:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
-
-
-class ChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    @abstractmethod
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        raise NotImplementedError
-
-
-class DefaultChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
-        )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        return llama.create_completion(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            stream=stream,
-            max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        )
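The deleted `DefaultChatCompletionFormat` built its `### Human:` / `### Assistant:` prompt inline and called `llama.create_completion` itself. Under the new layout that rendering would live in `llama_chat_format` as a plain formatter function returning a prompt plus stop strings; a hypothetical re-expression of the removed logic in that style (reusing the assumed `ChatFormatterResponse` shape from the earlier sketch, with `format_alpaca_style` as a made-up name):

```python
from typing import List


def format_alpaca_style(messages: List[dict], **kwargs) -> ChatFormatterResponse:
    # Same rendering the removed class performed inline:
    # "### Human:<content>" / "### Assistant:<content>", then prompt for a reply.
    chat_history = "".join(
        f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
        for message in messages
    )
    return ChatFormatterResponse(
        prompt=chat_history + "### Assistant:",
        stop=["### Assistant:", "### Human:"],
    )
```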