mtmd: Expose helper_decode_image_chunk #13366

Merged
merged 3 commits on May 8, 2025

129 changes: 78 additions & 51 deletions tools/mtmd/mtmd.cpp
@@ -580,6 +580,79 @@ struct decode_embd_batch {
    }
};

// Helper function for decoding an image whose embeddings have already been calculated
int32_t mtmd_helper_decode_image_chunk(
        mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
        float * encoded_embd,
        llama_pos n_past,
        llama_seq_id seq_id,
        int32_t n_batch,
        llama_pos * new_n_past) {
    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
        return -1;
    }
    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
    if (!image_tokens) {
        LOG_ERR("failed to decode image chunk: image tokens are null\n");
        return -1;
    }

    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
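    // M-RoPE (e.g. Qwen2-VL) uses 4 position values per embedding; standard RoPE uses a single position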
    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
    int32_t i_batch = 0;
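    // GGML_PAD rounds n_tokens up to a multiple of n_batch, so this is a ceiling division:
    // e.g. 576 image tokens with n_batch = 512 give n_img_batches = 2 (batches of 512 and 64)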
    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
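    // wrap the precomputed embeddings in a llama_batch; get_view() below slices it into decodable chunks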
    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

    const int nx = mtmd_image_tokens_get_nx(image_tokens);
    const int ny = mtmd_image_tokens_get_ny(image_tokens);

    if (mtmd_decode_use_mrope(ctx)) {
        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
    } else {
        batch_embd.set_position_normal(n_past, seq_id);
    }

    if (mtmd_decode_use_non_causal(ctx)) {
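        // some models (e.g. Gemma 3) require non-causal attention over the image tokens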
        llama_set_causal_attn(lctx, false);
        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
    }

    while (i_batch < n_img_batches) { // split into batches
        int pos_offset = i_batch*n_batch;
        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

        int64_t t1 = ggml_time_ms();
        int32_t ret = llama_decode(lctx, batch_embd_view);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_set_causal_attn(lctx, true); // restore causal attn
            return ret;
        }

        if (ctx->print_timings) {
            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
        }

        i_batch++;
    }

    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
    *new_n_past = n_past;

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, true);
    }
    return 0;
}

int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        struct llama_context * lctx,
        const mtmd_input_chunk * chunk,
@@ -591,8 +664,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
    int32_t ret;
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
    auto chunk_type = mtmd_input_chunk_get_type(chunk);
    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
        size_t n_tokens;
@@ -637,57 +708,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        if (ctx->print_timings) {
            LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
        }

        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
        int32_t i_batch = 0;
        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
        float * embd = mtmd_get_output_embd(ctx);
        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

        const int nx = mtmd_image_tokens_get_nx(image_tokens);
        const int ny = mtmd_image_tokens_get_ny(image_tokens);

        if (mtmd_decode_use_mrope(ctx)) {
            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
        } else {
            batch_embd.set_position_normal(n_past, seq_id);
        }

        if (mtmd_decode_use_non_causal(ctx)) {
            llama_set_causal_attn(lctx, false);
            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
        }

        while (i_batch < n_img_batches) { // split into batches
            int pos_offset = i_batch*n_batch;
            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

            int64_t t1 = ggml_time_ms();
            ret = llama_decode(lctx, batch_embd_view);
            if (ret != 0) {
                LOG_ERR("failed to decode image\n");
                llama_set_causal_attn(lctx, true); // restore causal attn
                llama_batch_free(text_batch);
                return ret;
            }

            if (ctx->print_timings) {
                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
            }

            i_batch++;
        }

        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
        *new_n_past = n_past;

        if (mtmd_decode_use_non_causal(ctx)) {
            llama_set_causal_attn(lctx, true);
        ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
        if (ret != 0) {
            LOG_ERR("failed to decode image\n");
            llama_batch_free(text_batch);
            return ret;
        }

    } else {
        GGML_ABORT("chunk type not supported");
    }
12 changes: 12 additions & 0 deletions tools/mtmd/mtmd.h
@@ -231,6 +231,18 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
                                               bool logits_last,
                                               llama_pos * new_n_past);

// helper function to decode an image whose embeddings have already been calculated
// this helper handles batching and pre/post decoding setup (for example, Gemma 3 requires non-causal attention)
// returns 0 on success, -1 if the chunk is not a valid image chunk, 1 on decode failure
MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
                                                struct llama_context * lctx,
                                                const mtmd_input_chunk * chunk,
                                                float * encoded_embd,
                                                llama_pos n_past,
                                                llama_seq_id seq_id,
                                                int32_t n_batch,
                                                llama_pos * new_n_past);

/////////////////////////////////////////

// test function, to be used in test-mtmd-c-api.c
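
For illustration, here is a minimal sketch of how a caller might use the new helper: encode an image chunk once, keep a copy of its embeddings, and decode the copy later without re-running the vision encoder. This sketch is not part of the PR; the wrappers cache_image_embd and replay_image are hypothetical, and it assumes the pre-existing mtmd_encode, mtmd_get_output_embd, llama_get_model, and llama_model_n_embd APIs.

#include <cstdio>
#include <vector>

#include "llama.h"
#include "mtmd.h"

// Hypothetical sketch: encode an image chunk once and keep a copy of its
// embeddings; the buffer returned by mtmd_get_output_embd() is owned by the
// mtmd context, so the copy lets the result outlive later encoder calls.
static int cache_image_embd(mtmd_context * ctx,
                            struct llama_context * lctx,
                            const mtmd_input_chunk * chunk,
                            std::vector<float> & cached) {
    const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
    if (!image_tokens || mtmd_encode(ctx, image_tokens) != 0) {
        fprintf(stderr, "failed to encode image\n");
        return 1;
    }
    // the mmproj output has the same embedding size as the text model
    const int     n_embd   = llama_model_n_embd(llama_get_model(lctx));
    const size_t  n_floats = (size_t) mtmd_image_tokens_get_n_tokens(image_tokens) * n_embd;
    const float * embd     = mtmd_get_output_embd(ctx);
    cached.assign(embd, embd + n_floats);
    return 0;
}

// Hypothetical sketch: later, feed the cached embeddings back through the new
// helper; batching and the non-causal attention toggle are handled internally.
static int replay_image(mtmd_context * ctx,
                        struct llama_context * lctx,
                        const mtmd_input_chunk * chunk,
                        std::vector<float> & cached,
                        llama_pos & n_past,
                        int32_t n_batch) {
    llama_pos new_n_past = 0;
    int32_t ret = mtmd_helper_decode_image_chunk(
        ctx, lctx, chunk, cached.data(),
        n_past, /*seq_id=*/0, n_batch, &new_n_past);
    if (ret != 0) {
        fprintf(stderr, "failed to decode cached image chunk\n");
        return ret;
    }
    n_past = new_n_past; // advance past the image positions
    return 0;
}

The point of the exposed helper is that encoded_embd can be any caller-owned buffer laid out as n_tokens rows of the model embedding size; whether a copy is actually needed depends on the caller.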