From cd548bd0f14210627798237d5c2ea78acfb88ccb Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 3 Jul 2025 01:57:43 -0400 Subject: [PATCH 1/3] feat: Add support for new mtmd api, add Qwen2.5-VL chat handler --- CMakeLists.txt | 83 +++++---- llama_cpp/llama_chat_format.py | 301 ++++++++++++++++++++++++--------- llama_cpp/mtmd_cpp.py | 280 ++++++++++++++++++++++++++++++ llama_cpp/server/model.py | 14 ++ 4 files changed, 554 insertions(+), 124 deletions(-) create mode 100644 llama_cpp/mtmd_cpp.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 505c024b2..4b06d98b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,46 +143,45 @@ if (LLAMA_BUILD) ) endif() - # if (LLAVA_BUILD) - # if (LLAMA_CUBLAS OR LLAMA_CUDA) - # add_compile_definitions(GGML_USE_CUBLAS) - # add_compile_definitions(GGML_USE_CUDA) - # endif() - # - # if (LLAMA_METAL) - # add_compile_definitions(GGML_USE_METAL) - # endif() - # - # # Building llava - # add_subdirectory(vendor/llama.cpp/tools/mtmd) - # set_target_properties(llava_shared PROPERTIES OUTPUT_NAME "llava") - # - # if (WIN32) - # set_target_properties(llava_shared PROPERTIES CUDA_ARCHITECTURES OFF) - # endif() - # llama_cpp_python_install_target(llava_shared) - # if (WIN32) - # install( - # FILES $ - # DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - # ) - # install( - # FILES $ - # DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp/lib - # ) - # endif() - # - # # Fix for llava build: Add include directory for llama.h - # # Move these commands after the add_subdirectory call - # target_include_directories(llava PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - # target_include_directories(llava PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) - # - # if (BUILD_SHARED_LIBS) - # target_include_directories(llava_shared PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - # target_include_directories(llava_shared PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) - # endif() - # - # target_include_directories(llama-llava-cli PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - # target_include_directories(llama-minicpmv-cli PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - # endif() + if (LLAVA_BUILD) + if (LLAMA_CUBLAS OR LLAMA_CUDA) + add_compile_definitions(GGML_USE_CUBLAS) + add_compile_definitions(GGML_USE_CUDA) + endif() + + if (LLAMA_METAL) + add_compile_definitions(GGML_USE_METAL) + endif() + + # Building llava + add_subdirectory(vendor/llama.cpp/tools/mtmd) + + if (WIN32) + set_target_properties(mtmd PROPERTIES CUDA_ARCHITECTURES OFF) + endif() + llama_cpp_python_install_target(mtmd) + if (WIN32) + install( + FILES $ + DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + ) + install( + FILES $ + DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp/lib + ) + endif() + + # Fix for mtmd build: Add include directory for llama.h + # Move these commands after the add_subdirectory call + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) + + if (BUILD_SHARED_LIBS) + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) + endif() + + # target_include_directories(llama-llava-cli PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + # target_include_directories(llama-minicpmv-cli PUBLIC 
${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + endif() endif() diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 17575c700..11208b09e 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -28,6 +28,7 @@ import numpy as np import numpy.typing as npt +import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama as llama import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar @@ -2651,7 +2652,7 @@ def generate_streaming(tools, functions, function_call, prompt): class Llava15ChatHandler: DEFAULT_SYSTEM_MESSAGE: Optional[str] = ( - "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." ) CHAT_FORMAT = ( @@ -2690,70 +2691,72 @@ class Llava15ChatHandler: ) def __init__(self, clip_model_path: str, verbose: bool = True): - import llama_cpp.llava_cpp as llava_cpp + import llama_cpp.mtmd_cpp as mtmd_cpp self.clip_model_path = clip_model_path self.verbose = verbose - - self._llava_cpp = llava_cpp # TODO: Fix + self._mtmd_cpp = mtmd_cpp self._exit_stack = ExitStack() - self._last_image_embed: Optional[ - llava_cpp.CtypesPointer[llava_cpp.llava_image_embed] - ] = None - self._last_image_hash: Optional[int] = None + self.mtmd_ctx: Optional[mtmd_cpp.mtmd_context_p] = None if not os.path.exists(clip_model_path): raise ValueError(f"Clip model path does not exist: {clip_model_path}") + def _init_mtmd_context(self, llama_model: llama.Llama): + """Initialize mtmd context with the llama model.""" + if self.mtmd_ctx is not None: + return # Already initialized + with suppress_stdout_stderr(disable=self.verbose): - clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0) + # Get default parameters + ctx_params = self._mtmd_cpp.mtmd_context_params_default() + # ctx_params.use_gpu = True + ctx_params.print_timings = self.verbose + ctx_params.n_threads = 16 + ctx_params.verbosity = 2 if self.verbose else 0 # GGML_LOG_LEVEL_INFO = 2 + + # Initialize mtmd context + self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file( + self.clip_model_path.encode(), + llama_model.model, + ctx_params + ) - if clip_ctx is None: - raise ValueError(f"Failed to load clip model: {clip_model_path}") + if self.mtmd_ctx is None: + raise ValueError(f"Failed to load mtmd context from: {self.clip_model_path}") - self.clip_ctx = clip_ctx + # Check if vision is supported + if not self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx): + raise ValueError("Vision is not supported by this model") - def clip_free(): + def mtmd_free(): with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.clip_free(self.clip_ctx) - - self._exit_stack.callback(clip_free) + if self.mtmd_ctx is not None: + self._mtmd_cpp.mtmd_free(self.mtmd_ctx) + self.mtmd_ctx = None - def last_image_embed_free(): - with suppress_stdout_stderr(disable=self.verbose): - if self._last_image_embed is not None: - self._llava_cpp.llava_image_embed_free(self._last_image_embed) - self._last_image_embed = None - - self._exit_stack.callback(last_image_embed_free) + self._exit_stack.callback(mtmd_free) def load_image(self, image_url: str) -> bytes: return self._load_image(image_url) - def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1): - if ( - self._last_image_embed is not None - 
and self._last_image_hash is not None - and hash(image_bytes) == self._last_image_hash - ): - return self._last_image_embed + def _create_bitmap_from_bytes(self, image_bytes: bytes): + """Create mtmd_bitmap from image bytes.""" + if self.mtmd_ctx is None: + raise ValueError("mtmd context not initialized") + with suppress_stdout_stderr(disable=self.verbose): - # Free the previous image embed - if self._last_image_embed is not None: - self._llava_cpp.llava_image_embed_free(self._last_image_embed) - self._last_image_embed = None - self._last_image_hash = None - embed = self._llava_cpp.llava_image_embed_make_with_bytes( - self.clip_ctx, - n_threads_batch, - (ctypes.c_uint8 * len(image_bytes)).from_buffer( - bytearray(image_bytes) - ), - len(image_bytes), + # Create bitmap from buffer using helper function + bitmap = self._mtmd_cpp.mtmd_helper_bitmap_init_from_buf( + self.mtmd_ctx, + (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)), + len(image_bytes) ) - self._last_image_embed = embed - self._last_image_hash = hash(image_bytes) - return embed + + if bitmap is None: + raise ValueError("Failed to create bitmap from image bytes") + + return bitmap def __call__( self, @@ -2794,7 +2797,9 @@ def __call__( llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], ]: - assert self.clip_ctx is not None + # Initialize mtmd context + self._init_mtmd_context(llama) + assert self.mtmd_ctx is not None system_prompt = _get_system_message(messages) if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None: @@ -2809,54 +2814,131 @@ def __call__( trim_blocks=True, lstrip_blocks=True, ).from_string(self.CHAT_FORMAT) + + # Get the default media marker + media_marker = self._mtmd_cpp.mtmd_default_marker().decode('utf-8') + + # Replace image URLs with media markers in the template text = template.render( messages=messages, add_generation_prompt=True, eos_token=llama.detokenize([llama.token_eos()]), bos_token=llama.detokenize([llama.token_bos()]), ) - split_text = self.split_text_on_image_urls(text, image_urls) + + # Replace image URLs in text with media markers + for image_url in image_urls: + text = text.replace(image_url, media_marker) if self.verbose: print(text, file=sys.stderr) + # Create bitmaps from images + bitmaps = [] + bitmap_cleanup = [] + try: + for image_url in image_urls: + image_bytes = self.load_image(image_url) + bitmap = self._create_bitmap_from_bytes(image_bytes) + bitmaps.append(bitmap) + bitmap_cleanup.append(bitmap) + + # Create input text structure + input_text = self._mtmd_cpp.mtmd_input_text() + input_text.text = text.encode('utf-8') + input_text.add_special = True + input_text.parse_special = True + + # Create input chunks + chunks = self._mtmd_cpp.mtmd_input_chunks_init() + if chunks is None: + raise ValueError("Failed to create input chunks") - # Evaluate prompt - llama.reset() - llama._ctx.kv_cache_clear() - for type_, value in split_text: - if type_ == "text": - tokens = llama.tokenize( - value.encode("utf8"), add_bos=False, special=True + try: + # Tokenize text and images together + bitmap_array = (self._mtmd_cpp.mtmd_bitmap_p_ctypes * len(bitmaps))(*bitmaps) + result = self._mtmd_cpp.mtmd_tokenize( + self.mtmd_ctx, + chunks, + ctypes.byref(input_text), + bitmap_array, + len(bitmaps) ) - if llama.n_tokens + len(tokens) > llama.n_ctx(): - raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}" - ) - llama.eval(tokens) - else: - image_bytes = self.load_image(value) - embed = 
self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) - if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): - raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" - ) - n_past = ctypes.c_int(llama.n_tokens) - n_past_p = ctypes.pointer(n_past) - with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.llava_eval_image_embed( - llama.ctx, - embed, - llama.n_batch, - n_past_p, - ) - # Required to avoid issues with hf tokenizer - llama.input_ids[llama.n_tokens : n_past.value] = -1 - llama.n_tokens = n_past.value - # Get prompt tokens to avoid a cache miss - prompt = llama.input_ids[: llama.n_tokens].tolist() + if result != 0: + raise ValueError(f"Failed to tokenize input: error code {result}") + + # Reset llama context + llama.reset() + llama._ctx.kv_cache_clear() + + # Process each chunk + n_past = llama_cpp.llama_pos(0) + n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks) + + for i in range(n_chunks): + chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i) + if chunk is None: + continue + + chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk) + + if chunk_type == self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_TEXT: + # Handle text chunk + n_tokens_out = ctypes.c_size_t() + tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text( + chunk, ctypes.byref(n_tokens_out) + ) + + if tokens_ptr and n_tokens_out.value > 0: + # Convert ctypes array to Python list + tokens = [tokens_ptr[j] for j in range(n_tokens_out.value)] + + if llama.n_tokens + len(tokens) > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}" + ) + llama.eval(tokens) + + elif chunk_type in [self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_AUDIO]: + # Handle image/audio chunk using helper + chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk) + + if llama.n_tokens + chunk_n_tokens > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}" + ) + + new_n_past = llama_cpp.llama_pos(0) + result = self._mtmd_cpp.mtmd_helper_eval_chunk_single( + self.mtmd_ctx, + llama._ctx.ctx, + chunk, + llama_cpp.llama_pos(llama.n_tokens), + llama_cpp.llama_seq_id(0), + llama.n_batch, + False, # logits_last + ctypes.byref(new_n_past) + ) + + if result != 0: + raise ValueError(f"Failed to evaluate chunk: error code {result}") + + # Update llama's token count + llama.n_tokens = new_n_past.value + + # Get prompt tokens to avoid a cache miss + prompt = llama.input_ids[: llama.n_tokens].tolist() + finally: + self._mtmd_cpp.mtmd_input_chunks_free(chunks) + + finally: + # Cleanup bitmaps + for bitmap in bitmap_cleanup: + self._mtmd_cpp.mtmd_bitmap_free(bitmap) + + # Handle response format and tools (same as before) if response_format is not None and response_format["type"] == "json_object": grammar = _grammar_for_response_format(response_format) @@ -2931,6 +3013,7 @@ def __call__( grammar=grammar, logit_bias=logit_bias, ) + if tool is not None: tool_name = tool["function"]["name"] return _convert_completion_to_chat_function( @@ -2943,12 +3026,10 @@ def _load_image(image_url: str) -> bytes: # TODO: Add Pillow support for other image formats beyond (jpg, png) if image_url.startswith("data:"): import base64 - image_bytes = base64.b64decode(image_url.split(",")[1]) return image_bytes else: import urllib.request - with urllib.request.urlopen(image_url) as f: image_bytes = f.read() return 
image_bytes @@ -2974,6 +3055,7 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): @staticmethod def split_text_on_image_urls(text: str, image_urls: List[str]): + """This method is no longer used in the new implementation.""" def find_first(s: str, substrs: List[str]): for i, substr in enumerate(substrs): pos = s.find(substr) @@ -3373,6 +3455,61 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +class Qwen25VLChatHandler(Llava15ChatHandler): + DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." + + CHAT_FORMAT = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "<|im_start|>user\n" + "{% if message['content'] is string %}" + "{{ message['content'] }}" + "{% else %}" + "{% for content in message['content'] %}" + "{% if content['type'] == 'text' %}" + "{{ content['text'] }}" + "{% elif content['type'] == 'image_url' %}" + "{% if content.image_url is string %}" + "{{ content.image_url }}" + "{% else %}" + "{{ content.image_url.url }}" + "{% endif %}" + "{% endif %}" + "{% endfor %}" + "{% endif %}" + "<|im_end|>\n" + "{% endif %}" + "{% endfor %}" + "<|im_start|>assistant\n" + ) + + def __call__(self, **kwargs): + llama = kwargs['llama'] + + # Clear state for multiple runs + llama.reset() + llama._ctx.kv_cache_clear() + llama.n_tokens = 0 + + if hasattr(llama, 'input_ids'): + llama.input_ids.fill(0) + + # Clear any handler state + if hasattr(self, '_last_image_embed'): + self._last_image_embed = None + self._last_image_hash = None + + if self.verbose: + messages = kwargs.get('messages', []) + image_count = len(self.get_image_urls(messages)) + print(f"Minimal - Cleared state, processing {image_count} images", file=sys.stderr) + + # Use parent implementation + return super().__call__(**kwargs) + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py new file mode 100644 index 000000000..a45f8f406 --- /dev/null +++ b/llama_cpp/mtmd_cpp.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +from ctypes import ( + c_bool, + c_char_p, + c_int, + c_uint8, + c_uint32, + c_float, + c_void_p, + c_size_t, + POINTER, + _Pointer, # type: ignore + Structure, + byref, +) +import pathlib +from typing import ( + Union, + NewType, + Optional, + TYPE_CHECKING, +) + +import llama_cpp.llama_cpp as llama_cpp + +from llama_cpp._ctypes_extensions import ( + load_shared_library, + ctypes_function_for_shared_library, +) + +if TYPE_CHECKING: + from llama_cpp._ctypes_extensions import ( + CtypesArray, + ) + + +# Specify the base name of the shared library to load +_libmtmd_base_name = "mtmd" +_libmtmd_override_path = os.environ.get("MTMD_CPP_LIB") +_libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path() + +# Load the library +_libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path) + +ctypes_function = ctypes_function_for_shared_library(_libmtmd) + +################################################ +# mtmd.h types +################################################ + +# Opaque types +mtmd_context_p = NewType("mtmd_context_p", int) +mtmd_context_p_ctypes = c_void_p + +mtmd_bitmap_p = NewType("mtmd_bitmap_p", int) +mtmd_bitmap_p_ctypes = c_void_p + +mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int) +mtmd_image_tokens_p_ctypes = c_void_p + +mtmd_input_chunk_p = 
NewType("mtmd_input_chunk_p", int) +mtmd_input_chunk_p_ctypes = c_void_p + +mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int) +mtmd_input_chunks_p_ctypes = c_void_p + +# Enums +MTMD_INPUT_CHUNK_TYPE_TEXT = 0 +MTMD_INPUT_CHUNK_TYPE_IMAGE = 1 +MTMD_INPUT_CHUNK_TYPE_AUDIO = 2 + +# Structures +class mtmd_context_params(Structure): + _fields_ = [ + ("use_gpu", c_bool), + ("print_timings", c_bool), + ("n_threads", c_int), + ("verbosity", c_int), # ggml_log_level + ("image_marker", c_char_p), + ("media_marker", c_char_p), + ] + +class mtmd_input_text(Structure): + _fields_ = [ + ("text", c_char_p), + ("add_special", c_bool), + ("parse_special", c_bool), + ] + +################################################ +# mtmd.h functions +################################################ + +# MTMD_API const char * mtmd_default_marker(void); +@ctypes_function("mtmd_default_marker", [], c_char_p) +def mtmd_default_marker() -> bytes: + ... + +# MTMD_API struct mtmd_context_params mtmd_context_params_default(void); +@ctypes_function("mtmd_context_params_default", [], mtmd_context_params) +def mtmd_context_params_default() -> mtmd_context_params: + ... + +# MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, +# const struct llama_model * text_model, +# const struct mtmd_context_params ctx_params); +@ctypes_function( + "mtmd_init_from_file", + [c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params], + mtmd_context_p_ctypes +) +def mtmd_init_from_file( + mmproj_fname: bytes, + text_model: llama_cpp.llama_model_p, + ctx_params: mtmd_context_params, + /, +) -> Optional[mtmd_context_p]: + ... + +# MTMD_API void mtmd_free(mtmd_context * ctx); +@ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None) +def mtmd_free(ctx: mtmd_context_p, /): + ... + +# MTMD_API bool mtmd_support_vision(mtmd_context * ctx); +@ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) +def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: + ... + +# MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data); +@ctypes_function( + "mtmd_bitmap_init", + [c_uint32, c_uint32, POINTER(c_uint8)], + mtmd_bitmap_p_ctypes +) +def mtmd_bitmap_init( + nx: Union[c_uint32, int], + ny: Union[c_uint32, int], + data: CtypesArray[c_uint8], + /, +) -> Optional[mtmd_bitmap_p]: + ... + +# MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None) +def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): + ... + +# MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); +@ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) +def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: + ... + +# MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); +@ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None) +def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /): + ... + +# MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); +@ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t) +def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int: + ... + +# MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx); +@ctypes_function( + "mtmd_input_chunks_get", + [mtmd_input_chunks_p_ctypes, c_size_t], + mtmd_input_chunk_p_ctypes +) +def mtmd_input_chunks_get( + chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], / +) -> Optional[mtmd_input_chunk_p]: + ... 
+ +# MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, +# mtmd_input_chunks * output, +# const mtmd_input_text * text, +# const mtmd_bitmap ** bitmaps, +# size_t n_bitmaps); +@ctypes_function( + "mtmd_tokenize", + [ + mtmd_context_p_ctypes, + mtmd_input_chunks_p_ctypes, + POINTER(mtmd_input_text), + POINTER(mtmd_bitmap_p_ctypes), + c_size_t, + ], + c_int, +) +def mtmd_tokenize( + ctx: mtmd_context_p, + output: mtmd_input_chunks_p, + text: "_Pointer[mtmd_input_text]", + bitmaps: CtypesArray[mtmd_bitmap_p_ctypes], + n_bitmaps: Union[c_size_t, int], + /, +) -> int: + ... + +# MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk); +@ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t) +def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int: + ... + +# MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk); +@ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int) +def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int: + ... + +# MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output); +@ctypes_function( + "mtmd_input_chunk_get_tokens_text", + [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)], + POINTER(llama_cpp.llama_token) +) +def mtmd_input_chunk_get_tokens_text( + chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", / +) -> Optional["_Pointer[llama_cpp.llama_token]"]: + ... + +################################################ +# mtmd-helper.h functions +################################################ + +# MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len); +@ctypes_function( + "mtmd_helper_bitmap_init_from_buf", + [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t], + mtmd_bitmap_p_ctypes +) +def mtmd_helper_bitmap_init_from_buf( + ctx: mtmd_context_p, + buf: CtypesArray[c_uint8], + length: Union[c_size_t, int], + /, +) -> Optional[mtmd_bitmap_p]: + ... + +# MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks); +@ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t) +def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: + ... + +# MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, +# struct llama_context * lctx, +# const mtmd_input_chunk * chunk, +# llama_pos n_past, +# llama_seq_id seq_id, +# int32_t n_batch, +# bool logits_last, +# llama_pos * new_n_past); +@ctypes_function( + "mtmd_helper_eval_chunk_single", + [ + mtmd_context_p_ctypes, + llama_cpp.llama_context_p_ctypes, + mtmd_input_chunk_p_ctypes, + llama_cpp.llama_pos, + llama_cpp.llama_seq_id, + c_int, + c_bool, + POINTER(llama_cpp.llama_pos), + ], + c_int, +) +def mtmd_helper_eval_chunk_single( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunk: mtmd_input_chunk_p, + n_past: llama_cpp.llama_pos, + seq_id: llama_cpp.llama_seq_id, + n_batch: Union[c_int, int], + logits_last: Union[c_bool, bool], + new_n_past: "_Pointer[llama_cpp.llama_pos]", + /, +) -> int: + ... 
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index c6716f919..11bd363b5 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -171,6 +171,20 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: chat_handler = llama_cpp.llama_chat_format.MiniCPMv26ChatHandler( clip_model_path=settings.clip_model_path, verbose=settings.verbose ) + elif settings.chat_format == "qwen2.5-vl": + assert settings.clip_model_path is not None, "clip model not found" + if settings.hf_model_repo_id is not None: + chat_handler = ( + llama_cpp.llama_chat_format.Qwen25VLChatHandler.from_pretrained( + repo_id=settings.hf_model_repo_id, + filename=settings.clip_model_path, + verbose=settings.verbose, + ) + ) + else: + chat_handler = llama_cpp.llama_chat_format.Qwen25VLChatHandler( + clip_model_path=settings.clip_model_path, verbose=settings.verbose + ) elif settings.chat_format == "hf-autotokenizer": assert ( settings.hf_pretrained_model_name_or_path is not None From 07a979f9077f28b180f4102b9f089246327b96df Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 3 Jul 2025 02:01:24 -0400 Subject: [PATCH 2/3] fix: Use num_threads from llama model for mtmd --- llama_cpp/llama_chat_format.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 11208b09e..a288db7b0 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2710,9 +2710,9 @@ def _init_mtmd_context(self, llama_model: llama.Llama): with suppress_stdout_stderr(disable=self.verbose): # Get default parameters ctx_params = self._mtmd_cpp.mtmd_context_params_default() - # ctx_params.use_gpu = True + ctx_params.use_gpu = True # TODO: Make this configurable ctx_params.print_timings = self.verbose - ctx_params.n_threads = 16 + ctx_params.n_threads = llama_model.n_threads ctx_params.verbosity = 2 if self.verbose else 0 # GGML_LOG_LEVEL_INFO = 2 # Initialize mtmd context From 6f3f0bf4d1a7db581ea0f87b89212e3a08b5ded7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 3 Jul 2025 02:04:50 -0400 Subject: [PATCH 3/3] docs: Add Qwen2.5-VL to README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e00456580..088a23779 100644 --- a/README.md +++ b/README.md @@ -505,6 +505,7 @@ Below are the supported multi-modal models and their respective chat handlers (P | [nanollava](https://huggingface.co/abetlen/nanollava-gguf) | `NanollavaChatHandler` | `nanollava` | | [llama-3-vision-alpha](https://huggingface.co/abetlen/llama-3-vision-alpha-gguf) | `Llama3VisionAlphaChatHandler` | `llama-3-vision-alpha` | | [minicpm-v-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) | `MiniCPMv26ChatHandler` | `minicpm-v-2.6` | +| [qwen2.5-vl](https://huggingface.co/unsloth/Qwen2.5-VL-3B-Instruct-GGUF) | `Qwen25VLChatHandler` | `qwen2.5-vl` | Then you'll need to use a custom chat handler to load the clip model and process the chat messages and images.
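
For reference, a minimal usage sketch of the new handler (not part of the patch): it assumes the standard llama-cpp-python multimodal chat API; the GGUF/mmproj file paths, the image URL, and the context size below are placeholders.

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Qwen25VLChatHandler

# The handler loads the mmproj (vision projector) via the new mtmd API;
# the text model is loaded by Llama as usual.
chat_handler = Qwen25VLChatHandler(clip_model_path="./mmproj-Qwen2.5-VL-3B-Instruct-f16.gguf")

llm = Llama(
    model_path="./Qwen2.5-VL-3B-Instruct-Q4_K_M.gguf",
    chat_handler=chat_handler,
    n_ctx=8192,  # leave room for the image tokens produced by mtmd
)

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])
```

With the server, the same handler is selected by setting `chat_format` to `qwen2.5-vl` together with `clip_model_path`, per the `llama_cpp/server/model.py` change above.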