llama: Attempt to add ModernBert #14014


Open · wants to merge 34 commits into master
Conversation

huydt84 (Collaborator) commented Jun 4, 2025

I don't know whether my implementation is correct or not

@github-actions github-actions bot added the python (python script changes) label Jun 4, 2025
@huydt84 huydt84 marked this pull request as draft June 4, 2025 15:27
@huydt84 huydt84 marked this pull request as ready for review June 4, 2025 15:36
huydt84 (Collaborator, Author) commented Jun 4, 2025

hparams.set_swa_pattern can't work properly with ModernBert

@huydt84 huydt84 marked this pull request as draft June 4, 2025 15:40
huydt84 (Collaborator, Author) commented Jun 4, 2025

The embedding result seems random and very low. There is something wrong with this

@huydt84 huydt84 marked this pull request as ready for review June 4, 2025 16:21
CISC (Collaborator) left a comment

Delete the files you added in models; we don't need them. Just make sure test-tokenizer-0 succeeds with the GGUF.

@huydt84 huydt84 requested a review from CISC June 4, 2025 22:55
inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
cb(inpL, "inp_norm", -1);

auto * inp_attn = build_attn_inp_kv_unified_iswa();

ggerganov (Member) commented Jun 5, 2025

This should probably become:

Suggested change:
-   auto * inp_attn = build_attn_inp_kv_unified_iswa();
+   auto * inp_attn = build_attn_inp_no_cache_iswa();

And add the corresponding mask logic in llama-graph. Special attention should be paid to how the SWA works for this model, i.e. whether it is symmetric or not:

# non-symmetric:
token i attends to [i - n_swa, i]

# symmetric:
token i attends to [i - n_swa/2, i + n_swa/2]
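
For illustration, the two variants boil down to a mask predicate roughly like the following (a minimal sketch, not code from this PR; the helper name and signature are made up):

#include <cstdint>
#include <cstdlib>   // std::abs

// true if the token at position p1 is allowed to attend to the token at position p0
static bool swa_allows(int32_t p0, int32_t p1, int32_t n_swa, bool symmetric) {
    if (symmetric) {
        // token i attends to [i - n_swa/2, i + n_swa/2]
        return std::abs(p1 - p0) <= n_swa / 2;
    }
    // token i attends to [i - n_swa, i]
    return p1 - p0 >= 0 && p1 - p0 <= n_swa;
}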

CISC (Collaborator) commented:

@huydt-bti Hey, is the issue that you forgot to implement this function, so that SWA is actually never applied?

huydt84 (Collaborator, Author) replied:

@CISC No, I made it here:

if (hparams.use_alibi &&
    (hparams.n_swa == 0 || (pos_diff >= -half_n_swa && pos_diff <= half_n_swa))) {
    f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {

I use build_attn_inp_no_cache()
#14014 (review)

CISC (Collaborator) replied:

Yes, but there's no kq_mask_swa, so is this even executed?

huydt84 (Collaborator, Author) replied:

@CISC I have just implemented it. Please check again

ggerganov (Member) left a comment

You have to add the new arch here:

llama.cpp/src/llama-model.cpp

Lines 13195 to 13203 in 5a8ae30

switch (arch) {
    case LLM_ARCH_BERT:
    case LLM_ARCH_JINA_BERT_V2:
    case LLM_ARCH_NOMIC_BERT:
    case LLM_ARCH_NOMIC_BERT_MOE:
    case LLM_ARCH_WAVTOKENIZER_DEC:
        {
            res = nullptr;
        } break;

To avoid creating a memory module (a.k.a. KV cache) for these models.
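
For example, assuming the new enum value ends up being named LLM_ARCH_MODERN_BERT (the actual identifier is up to this PR), the switch would simply gain one more case:

switch (arch) {
    case LLM_ARCH_BERT:
    case LLM_ARCH_JINA_BERT_V2:
    case LLM_ARCH_NOMIC_BERT:
    case LLM_ARCH_NOMIC_BERT_MOE:
    case LLM_ARCH_MODERN_BERT:        // hypothetical name for the new arch
    case LLM_ARCH_WAVTOKENIZER_DEC:
        {
            // no memory module (KV cache) is created for these archs
            res = nullptr;
        } break;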

@huydt84 huydt84 requested a review from ggerganov June 5, 2025 13:55
CISC (Collaborator) commented Jun 5, 2025

So, since the vocab is BPE, you need to add modern-bert vocab handling in a few places:

tokenizer_pre == "roberta-bpe") {

Set the correct attribute on the [MASK] token, similarly to this:

llama.cpp/src/llama-vocab.cpp

Lines 2097 to 2105 in 9f47fa5

if (false
    || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
    || _contains_any(general_arch, {"nomic-bert-moe"})
) {
    if (token_to_id.count("<mask>") == 0) {
        LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
    } else {
        _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
    }
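
A rough sketch of the second part (assuming the mask token is spelled "[MASK]" and that a "modern-bert" string is used for the arch check; neither is confirmed here):

// llama-vocab.cpp, sketch only
if (_contains_any(general_arch, {"modern-bert"})) {
    if (token_to_id.count("[MASK]") == 0) {
        LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
    } else {
        _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
    }
}

The first part is just registering the new pre-tokenizer string next to the existing "roberta-bpe" case.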

CISC (Collaborator) commented Jun 5, 2025

> The embedding result seems random and very low. There is something wrong with this

Yep, I also noticed the same with jina-reranker-v2, most likely the same issue, will investigate.

huydt84 (Collaborator, Author) commented Jun 8, 2025

@CISC cc: @ggerganov

I tried running the embedding with various models, but the output results barely change across those attempts. Maybe the parameter loading or the inference graph has a problem somewhere. Can you check that part?
This is the model implementation in Hugging Face Transformers: https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/modernbert/modeling_modernbert.py

CISC (Collaborator) commented Jun 8, 2025

So, I just noticed at least part of the problem:

llama.cpp/src/llama-graph.cpp

Lines 1567 to 1571 in 3ac6753

if (cls != nullptr && cls_b != nullptr) {
    // classification head
    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
    cur = ggml_tanh(ctx0, cur);

We have cls, but not cls_b, so this has to be modified to handle that...
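
One way to relax that check (a sketch of the idea, not necessarily how this PR ends up resolving it) is to require only cls and add the bias conditionally; whether tanh is the right activation for ModernBert's head is a separate question:

if (cls != nullptr) {
    // classification head; the bias is optional for models that don't ship cls_b
    cur = ggml_mul_mat(ctx0, cls, inp);
    if (cls_b != nullptr) {
        cur = ggml_add(ctx0, cur, cls_b);
    }
    cur = ggml_tanh(ctx0, cur);
}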

huydt84 (Collaborator, Author) commented Jun 9, 2025

> We have cls, but not cls_b, so this has to be modified to handle that...

@CISC After fixing that, the result is much better now :) but it is still lower than my expectations for ModernBert. Maybe there is a problem somewhere else...

CISC (Collaborator) left a comment

Everything else LGTM, so pending finding the output issue.

@@ -1328,6 +1328,12 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                return true;
            }
        } break;
        case LLAMA_SWA_TYPE_SYMMETRIC:
            {
                if (p1 - p0 <= (int32_t) n_swa / 2 || p0 - p1 >= (int32_t) n_swa / 2) {

huydt84 (Collaborator, Author) commented Jun 16, 2025

@CISC I see part of the problem! I'm masking the tokens inside the window, when it should be the ones outside.
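
In other words, the symmetric case should mask only positions outside the window, roughly (sketch):

case LLAMA_SWA_TYPE_SYMMETRIC:
    {
        const int32_t half_n_swa = (int32_t) n_swa / 2;
        // masked only if p0 lies outside [p1 - n_swa/2, p1 + n_swa/2]
        if (p1 - p0 > half_n_swa || p0 - p1 > half_n_swa) {
            return true;
        }
    } break;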

huydt84 (Collaborator, Author) replied:

No, the function isn't used because it belongs to llama_kv_cache_unified

@@ -351,6 +351,69 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
            }
        }
    }

    // Handle symmetric SWA mask separately if it exists
    if (kq_mask_swa) {

CISC (Collaborator) commented:

No, this is unnecessary duplication; it should be handled like this once you add llm_graph_input/build_attn_inp_no_cache_iswa:

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask) {
        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
    }
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask) {
        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
    }
    if (self_kq_mask_swa) {
        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
    }
}
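
Applied to the no-cache path, that could end up looking roughly like this (a sketch only; the _iswa type and both helper methods are hypothetical names, not existing API):

void llm_graph_input_attn_no_cache_iswa::set_input(const llama_ubatch * ubatch) {
    if (kq_mask) {
        // full (non-windowed) mask, same logic as the existing no-cache path
        set_input_kq_mask(kq_mask, ubatch, cparams.causal_attn);
    }
    if (kq_mask_swa) {
        // windowed mask; the symmetric-window check lives inside this helper
        set_input_kq_mask_swa(kq_mask_swa, ubatch, cparams.causal_attn);
    }
}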

huydt84 (Collaborator, Author) replied:

It is not interleaved SWA, so I still prefer using build_attn_inp_no_cache. I will try to refactor llm_graph_input_attn_no_cache::set_input, but the mechanism is different from llm_graph_input_attn_kv_unified and llm_graph_input_attn_kv_unified_iswa, since those use kv_state.

CISC (Collaborator) replied:

Sure, the point is just that you can handle it similarly.

huydt84 (Collaborator, Author) replied:

I just pushed my code - but it's ugly to copy is_masked_swa and place it inside llm_graph_input_attn_no_cache::set_input. Do you have any suggestions?

CISC (Collaborator) replied:

I don't understand why you can't split the methods just like in unified; you can have an is_masked_swa implementation for no_cache and a much cleaner set_input.

huydt84 (Collaborator, Author) replied:

Personally I don't think that would be cleaner, since the new build_attn_inp_no_cache_swa would be almost the same as the current build_attn_inp_no_cache, and we would end up with an extra build_attn_inp_no_cache variant. But I will try that.

CISC (Collaborator) replied:

That doesn't matter; build_attn_inp_* are tiny. The important part is that you can reuse the same set_input_kq_mask.

@CISC CISC added the model (Model specific) label Jun 16, 2025