import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol

-import llama_cpp.llama_types as llama_types
import llama_cpp.llama as llama
+import llama_cpp.llama_types as llama_types
+import llama_cpp.llama_grammar as llama_grammar


class LlamaChatCompletionHandler(Protocol):
@@ -25,6 +26,9 @@ def __call__(
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
        max_tokens: int = 256,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
@@ -37,7 +41,10 @@ def __call__(
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        **kwargs,  # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
        ...


@@ -169,6 +176,7 @@ class ChatFormatterResponse:
class ChatFormatter(Protocol):
    def __call__(
        self,
+        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
@@ -264,17 +272,24 @@ def _convert_completion_to_chat(
def register_chat_format(name: str):
    def decorator(f: ChatFormatter):
        def basic_create_chat_completion(
+            *,
            llama: llama.Llama,
            messages: List[llama_types.ChatCompletionRequestMessage],
            functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
            function_call: Optional[
-                Union[str, llama_types.ChatCompletionFunctionCall]
+                llama_types.ChatCompletionRequestFunctionCall
            ] = None,
+            tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+            tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
            temperature: float = 0.2,
            top_p: float = 0.95,
            top_k: int = 40,
            stream: bool = False,
            stop: Optional[Union[str, List[str]]] = [],
+            seed: Optional[int] = None,
+            response_format: Optional[
+                llama_types.ChatCompletionRequestResponseFormat
+            ] = None,
            max_tokens: int = 256,
            presence_penalty: float = 0.0,
            frequency_penalty: float = 0.0,
@@ -286,8 +301,10 @@ def basic_create_chat_completion(
            model: Optional[str] = None,
            logits_processor: Optional[llama.LogitsProcessorList] = None,
            grammar: Optional[llama.LlamaGrammar] = None,
+            **kwargs,  # type: ignore
        ) -> Union[
-            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
+            llama_types.CreateChatCompletionResponse,
+            Iterator[llama_types.CreateChatCompletionStreamResponse],
        ]:
            result = f(
                messages=messages,
@@ -299,6 +316,10 @@ def basic_create_chat_completion(
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
            stop = stop + rstop
+
+            if response_format is not None and response_format["type"] == "json_object":
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)

            completion_or_chunks = llama.create_completion(
                prompt=prompt,
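Once this branch is wired in, JSON mode can be requested from the high-level API. A minimal usage sketch, assuming a hypothetical local GGUF model path and the create_chat_completion entry point that forwards response_format to the registered handler:

    from llama_cpp import Llama

    # Hypothetical model path; any chat-capable GGUF model would do.
    llm = Llama(model_path="./models/mistral-7b-instruct.Q4_K_M.gguf")

    # response_format={"type": "json_object"} swaps in the JSON_GBNF grammar,
    # constraining sampling to syntactically valid JSON.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "List three primes as a JSON array."}],
        response_format={"type": "json_object"},
    )
    print(response["choices"][0]["message"]["content"])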
@@ -307,6 +328,7 @@ def basic_create_chat_completion(
                top_k=top_k,
                stream=stream,
                stop=stop,
+                seed=seed,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
@@ -319,7 +341,7 @@ def basic_create_chat_completion(
                logits_processor=logits_processor,
                grammar=grammar,
            )
-            return _convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
+            return _convert_completion_to_chat(completion_or_chunks, stream=stream)

        register_chat_completion_handler(name)(basic_create_chat_completion)
        return f
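For reference, the decorator is used by defining a formatter that maps messages to a prompt string. A minimal sketch, assuming ChatFormatterResponse exposes prompt and stop fields as elsewhere in this module; the format name and template below are illustrative only:

    @register_chat_format("simple-instruct")  # hypothetical format name
    def format_simple_instruct(
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Flatten the conversation into "ROLE: content" lines (assumes string
        # contents) and leave the assistant turn open for the model to complete.
        prompt = "".join(f"{m['role'].upper()}: {m['content']}\n" for m in messages)
        return ChatFormatterResponse(prompt=prompt + "ASSISTANT:", stop="USER:")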
@@ -727,7 +749,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):

    assert "usage" in completion
    assert isinstance(function_call, str)
-    assert stream is False # TODO: support stream mode
+    assert stream is False  # TODO: support stream mode

    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
@@ -759,7 +781,9 @@ def __init__(self, clip_model_path: str):
        self._llava_cpp = llava_cpp
        self.clip_model_path = clip_model_path

-        self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
+        self.clip_ctx = self._llava_cpp.clip_model_load(
+            self.clip_model_path.encode(), 0
+        )

    def __del__(self):
        if self.clip_ctx is not None:
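In practice this handler is paired with a CLIP projector file at load time. A minimal construction sketch, assuming the class shown here is the Llava15ChatHandler exported by llama_cpp.llama_chat_format and hypothetical local model paths:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # The mmproj file holds the CLIP weights consumed by clip_model_load above.
    chat_handler = Llava15ChatHandler(clip_model_path="./models/mmproj-model-f16.gguf")

    llm = Llama(
        model_path="./models/llava-v1.5-7b.Q4_K_M.gguf",
        chat_handler=chat_handler,
        logits_all=True,  # required by the assertion in __call__ below
        n_ctx=2048,  # a larger context leaves room for image embeddings
    )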
@@ -805,64 +829,108 @@ def __call__(
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        **kwargs,  # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
-        assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert (
+            llama.context_params.logits_all is True
+        )  # BUG: logits_all=True is required for llava
        assert self.clip_ctx is not None
        system_prompt = _get_system_message(messages)
-        system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        system_prompt = (
+            system_prompt
+            if system_prompt != ""
+            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        )
        user_role = "\nUSER:"
        assistant_role = "\nASSISTANT:"
        llama.reset()
        llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
        for message in messages:
            if message["role"] == "user" and message["content"] is not None:
                if isinstance(message["content"], str):
-                    llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(
+                            f"{user_role} {message['content']}".encode("utf8"),
+                            add_bos=False,
+                        )
+                    )
                else:
                    assert isinstance(message["content"], list)
-                    llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
+                    )
                    for content in message["content"]:
                        if content["type"] == "text":
-                            llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
+                            llama.eval(
+                                llama.tokenize(
+                                    f"{content['text']}".encode("utf8"), add_bos=False
+                                )
+                            )
                        if content["type"] == "image_url":
-                            image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
+                            image_bytes = (
+                                self.load_image(content["image_url"]["url"])
+                                if isinstance(content["image_url"], dict)
+                                else self.load_image(content["image_url"])
+                            )
                            import array
-                            data_array = array.array('B', image_bytes)
-                            c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
-                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
+
+                            data_array = array.array("B", image_bytes)
+                            c_ubyte_ptr = (
+                                ctypes.c_ubyte * len(data_array)
+                            ).from_buffer(data_array)
+                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(
+                                ctx_clip=self.clip_ctx,
+                                n_threads=llama.context_params.n_threads,
+                                image_bytes=c_ubyte_ptr,
+                                image_bytes_length=len(image_bytes),
+                            )
                            # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
                            # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
                            try:
                                n_past = ctypes.c_int(llama.n_tokens)
                                n_past_p = ctypes.pointer(n_past)
-                                self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
+                                self._llava_cpp.llava_eval_image_embed(
+                                    ctx_llama=llama.ctx,
+                                    embed=embed,
+                                    n_batch=llama.n_batch,
+                                    n_past=n_past_p,
+                                )
                                assert llama.n_ctx() >= n_past.value
                                llama.n_tokens = n_past.value
                            finally:
                                self._llava_cpp.llava_image_embed_free(embed)
            if message["role"] == "assistant" and message["content"] is not None:
-                llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
+                llama.eval(
+                    llama.tokenize(
+                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
+                    )
+                )
        llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))

        prompt = llama._input_ids.tolist()
-        return _convert_completion_to_chat(llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+        return _convert_completion_to_chat(
+            llama.create_completion(
+                prompt=prompt,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stream=stream,
+                stop=stop,
+                max_tokens=max_tokens,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                repeat_penalty=repeat_penalty,
+                tfs_z=tfs_z,
+                mirostat_mode=mirostat_mode,
+                mirostat_tau=mirostat_tau,
+                mirostat_eta=mirostat_eta,
+                model=model,
+                logits_processor=logits_processor,
+                grammar=grammar,
+            ),
            stream=stream,
-            stop=stop,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        ), stream=stream)
+        )
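The loop above expects OpenAI-style multimodal content parts. A usage sketch, assuming llm was built with the handler as shown earlier; the image URL is illustrative:

    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You describe images precisely."},
            {
                "role": "user",
                "content": [
                    # Parsed by the content["type"] branches in __call__ above.
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                    {"type": "text", "text": "What is in this picture?"},
                ],
            },
        ],
    )
    print(response["choices"][0]["message"]["content"])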