Commit b30b9c3

Add JSON mode support. Closes abetlen#881
1 parent 4852a6a commit b30b9c3

File tree

4 files changed (+116 -39 lines):

  llama_cpp/llama.py
  llama_cpp/llama_chat_format.py
  llama_cpp/llama_types.py
  llama_cpp/server/app.py

llama_cpp/llama.py (+2)

@@ -1901,6 +1901,7 @@ def create_chat_completion(
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
+        response_format: Optional[ChatCompletionRequestResponseFormat] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1946,6 +1947,7 @@ def create_chat_completion(
             stream=stream,
             stop=stop,
             seed=seed,
+            response_format=response_format,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
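
Taken together, these two small additions thread the new response_format argument from the public create_chat_completion API through to the registered chat handler. A minimal usage sketch of the Python API after this commit; the model path and chat_format value below are illustrative assumptions, not part of the diff:

```python
from llama_cpp import Llama

# Assumed local GGUF model path and chat format; any chat-capable model works here.
llm = Llama(model_path="./models/mistral-7b-instruct.Q4_K_M.gguf", chat_format="chatml")

completion = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant that answers in JSON."},
        {"role": "user", "content": "Name three primary colors."},
    ],
    # New in this commit: constrain decoding to valid JSON via a GBNF grammar.
    response_format={"type": "json_object"},
    max_tokens=256,
)
print(completion["choices"][0]["message"]["content"])
```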

llama_cpp/llama_chat_format.py (+106 -38)

@@ -5,8 +5,9 @@
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 
-import llama_cpp.llama_types as llama_types
 import llama_cpp.llama as llama
+import llama_cpp.llama_types as llama_types
+import llama_cpp.llama_grammar as llama_grammar
 
 
 class LlamaChatCompletionHandler(Protocol):
@@ -25,6 +26,9 @@ def __call__(
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -37,7 +41,10 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs, # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
         ...
 
 
@@ -169,6 +176,7 @@ class ChatFormatterResponse:
 class ChatFormatter(Protocol):
     def __call__(
         self,
+        *,
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
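
The ChatFormatter protocol above now takes keyword-only arguments. For context, a sketch of what a formatter registered via register_chat_format has to look like under this signature; the format name, the flattening logic, and the ChatFormatterResponse(prompt=..., stop=...) fields are assumptions drawn from the surrounding code, not something this commit adds:

```python
from typing import Any, List

import llama_cpp.llama_types as llama_types
from llama_cpp.llama_chat_format import ChatFormatterResponse, register_chat_format

@register_chat_format("simple-roles")  # hypothetical format name
def format_simple_roles(
    *,  # keyword-only, matching the updated ChatFormatter protocol
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    # Flatten the conversation into "role: content" lines and cue the assistant.
    prompt = "".join(f"{m['role']}: {m.get('content') or ''}\n" for m in messages)
    return ChatFormatterResponse(prompt=prompt + "assistant:", stop=["\n"])
```

A model constructed with Llama(..., chat_format="simple-roles") would then route create_chat_completion calls through this formatter.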
@@ -264,17 +272,24 @@ def _convert_completion_to_chat(
 def register_chat_format(name: str):
     def decorator(f: ChatFormatter):
         def basic_create_chat_completion(
+            *,
             llama: llama.Llama,
             messages: List[llama_types.ChatCompletionRequestMessage],
             functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
             function_call: Optional[
-                Union[str, llama_types.ChatCompletionFunctionCall]
+                llama_types.ChatCompletionRequestFunctionCall
             ] = None,
+            tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+            tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
             temperature: float = 0.2,
             top_p: float = 0.95,
             top_k: int = 40,
             stream: bool = False,
             stop: Optional[Union[str, List[str]]] = [],
+            seed: Optional[int] = None,
+            response_format: Optional[
+                llama_types.ChatCompletionRequestResponseFormat
+            ] = None,
             max_tokens: int = 256,
             presence_penalty: float = 0.0,
             frequency_penalty: float = 0.0,
@@ -286,8 +301,10 @@ def basic_create_chat_completion(
             model: Optional[str] = None,
             logits_processor: Optional[llama.LogitsProcessorList] = None,
             grammar: Optional[llama.LlamaGrammar] = None,
+            **kwargs, # type: ignore
         ) -> Union[
-            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
+            llama_types.CreateChatCompletionResponse,
+            Iterator[llama_types.CreateChatCompletionStreamResponse],
         ]:
             result = f(
                 messages=messages,
@@ -299,6 +316,10 @@ def basic_create_chat_completion(
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
             stop = stop + rstop
+
+            if response_format is not None and response_format["type"] == "json_object":
+                print("hello world")
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
 
             completion_or_chunks = llama.create_completion(
                 prompt=prompt,
@@ -307,6 +328,7 @@ def basic_create_chat_completion(
                 top_k=top_k,
                 stream=stream,
                 stop=stop,
+                seed=seed,
                 max_tokens=max_tokens,
                 presence_penalty=presence_penalty,
                 frequency_penalty=frequency_penalty,
@@ -319,7 +341,7 @@ def basic_create_chat_completion(
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
-            return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
+            return _convert_completion_to_chat(completion_or_chunks, stream=stream)
 
         register_chat_completion_handler(name)(basic_create_chat_completion)
         return f
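
The @@ -299,6 +316,10 hunk above is the core of JSON mode: when response_format is {"type": "json_object"}, the handler swaps in the bundled JSON GBNF grammar before calling create_completion. A standalone sketch of that decision, using only llama_grammar.JSON_GBNF and LlamaGrammar.from_string as they appear in the diff; the helper name and its signature are hypothetical:

```python
from typing import Optional

import llama_cpp.llama_grammar as llama_grammar
from llama_cpp.llama_grammar import LlamaGrammar

def grammar_for_response_format(
    response_format: Optional[dict],
    grammar: Optional[LlamaGrammar] = None,
) -> Optional[LlamaGrammar]:
    """Hypothetical helper mirroring basic_create_chat_completion:
    JSON mode overrides whatever grammar the caller passed in."""
    if response_format is not None and response_format["type"] == "json_object":
        # Constrain sampling to syntactically valid JSON.
        return llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
    return grammar

# The resulting grammar is forwarded to Llama.create_completion(grammar=...),
# so every sampled token must fit the JSON grammar.
```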
@@ -727,7 +749,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
 
     assert "usage" in completion
     assert isinstance(function_call, str)
-    assert stream is False # TODO: support stream mode
+    assert stream is False  # TODO: support stream mode
 
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
@@ -759,7 +781,9 @@ def __init__(self, clip_model_path: str):
         self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
 
-        self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
+        self.clip_ctx = self._llava_cpp.clip_model_load(
+            self.clip_model_path.encode(), 0
+        )
 
     def __del__(self):
         if self.clip_ctx is not None:
@@ -805,64 +829,108 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs, # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
-        assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert (
+            llama.context_params.logits_all is True
+        ) # BUG: logits_all=True is required for llava
         assert self.clip_ctx is not None
         system_prompt = _get_system_message(messages)
-        system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        system_prompt = (
+            system_prompt
+            if system_prompt != ""
+            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        )
+        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         user_role = "\nUSER:"
         assistant_role = "\nASSISTANT:"
         llama.reset()
         llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
         for message in messages:
             if message["role"] == "user" and message["content"] is not None:
                 if isinstance(message["content"], str):
-                    llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(
+                            f"{user_role} {message['content']}".encode("utf8"),
+                            add_bos=False,
+                        )
+                    )
                 else:
                     assert isinstance(message["content"], list)
-                    llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
+                    )
                     for content in message["content"]:
                         if content["type"] == "text":
-                            llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
+                            llama.eval(
+                                llama.tokenize(
+                                    f"{content['text']}".encode("utf8"), add_bos=False
+                                )
+                            )
                         if content["type"] == "image_url":
-                            image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
+                            image_bytes = (
+                                self.load_image(content["image_url"]["url"])
+                                if isinstance(content["image_url"], dict)
+                                else self.load_image(content["image_url"])
+                            )
                             import array
-                            data_array = array.array('B', image_bytes)
-                            c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
-                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
+
+                            data_array = array.array("B", image_bytes)
+                            c_ubyte_ptr = (
+                                ctypes.c_ubyte * len(data_array)
+                            ).from_buffer(data_array)
+                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(
+                                ctx_clip=self.clip_ctx,
+                                n_threads=llama.context_params.n_threads,
+                                image_bytes=c_ubyte_ptr,
+                                image_bytes_length=len(image_bytes),
+                            )
                             # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
                             # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
                             try:
                                 n_past = ctypes.c_int(llama.n_tokens)
                                 n_past_p = ctypes.pointer(n_past)
-                                self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
+                                self._llava_cpp.llava_eval_image_embed(
+                                    ctx_llama=llama.ctx,
+                                    embed=embed,
+                                    n_batch=llama.n_batch,
+                                    n_past=n_past_p,
+                                )
                                 assert llama.n_ctx() >= n_past.value
                                 llama.n_tokens = n_past.value
                             finally:
                                 self._llava_cpp.llava_image_embed_free(embed)
             if message["role"] == "assistant" and message["content"] is not None:
-                llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
+                llama.eval(
+                    llama.tokenize(
+                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
+                    )
+                )
         llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
 
         prompt = llama._input_ids.tolist()
 
-        return _convert_completion_to_chat(llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+        return _convert_completion_to_chat(
+            llama.create_completion(
+                prompt=prompt,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stream=stream,
+                stop=stop,
+                max_tokens=max_tokens,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                repeat_penalty=repeat_penalty,
+                tfs_z=tfs_z,
+                mirostat_mode=mirostat_mode,
+                mirostat_tau=mirostat_tau,
+                mirostat_eta=mirostat_eta,
+                model=model,
+                logits_processor=logits_processor,
+                grammar=grammar,
+            ),
             stream=stream,
-            stop=stop,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        ), stream=stream)
+        )
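
The Llava15ChatHandler changes above are mostly formatting, but the assertions spell out real requirements: the Llama context must be created with logits_all=True and a CLIP model must be loaded. A hedged usage sketch, assuming the chat_handler hook the library uses to route chat completions through this handler; both model paths and the image URL are placeholders:

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Placeholder paths for a LLaVA 1.5 GGUF model and its CLIP projector.
chat_handler = Llava15ChatHandler(clip_model_path="./models/mmproj-model-f16.gguf")
llm = Llama(
    model_path="./models/llava-v1.5-7b.Q4_K_M.gguf",
    chat_handler=chat_handler,
    logits_all=True,  # the handler asserts this; llava needs logits for every token
    n_ctx=2048,       # leave room for the image embedding
)

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])
```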

llama_cpp/llama_types.py (+5 -1)

@@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
     name: str
 
 
+class ChatCompletionRequestResponseFormat(TypedDict):
+    type: Literal["text", "json_object"]
+
+
 class ChatCompletionRequestMessageContentPartText(TypedDict):
     type: Literal["text"]
     text: str
@@ -241,7 +245,7 @@ class ChatCompletionRequestFunctionCallOption(TypedDict):
     Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
 ]
 
-ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
+ChatCompletionFunctionParameters = Dict[str, JsonType]  # TODO: make this more specific
 
 
 class ChatCompletionToolFunction(TypedDict):
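
The new TypedDict mirrors the shape of OpenAI's response_format request field. A small typing sketch; the variable names are illustrative only:

```python
from llama_cpp.llama_types import ChatCompletionRequestResponseFormat

# A type checker accepts only {"type": "text"} or {"type": "json_object"} here.
json_mode: ChatCompletionRequestResponseFormat = {"type": "json_object"}
plain_text: ChatCompletionRequestResponseFormat = {"type": "text"}
```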

llama_cpp/server/app.py (+3)

@@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
     seed: Optional[int] = Field(None)
+    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
+        default=None,
+    )
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
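
With response_format exposed on CreateChatCompletionRequest, the OpenAI-compatible server accepts the field in the request body. A sketch of a client call, assuming the server is running locally on its default port (8000); the prompt and max_tokens are arbitrary:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [
            {"role": "user", "content": "Return a JSON object with keys name and age."}
        ],
        # Field added in this commit; "text" and "json_object" are the valid types.
        "response_format": {"type": "json_object"},
        "max_tokens": 128,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])
```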
