llama: Attempt to add ModernBert #14014
base: master
Conversation
The embedding result seems random and very low. There is something wrong with this.
Delete the files you added in models, we don't need them; just make sure test-tokenizer-0 succeeds with the GGUF.
src/llama-model.cpp
Outdated
inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
cb(inpL, "inp_norm", -1);

auto * inp_attn = build_attn_inp_kv_unified_iswa();
This should probably become:
- auto * inp_attn = build_attn_inp_kv_unified_iswa();
+ auto * inp_attn = build_attn_inp_no_cache_iswa();
And add the corresponding mask logic in llama-graph. Special attention should be paid to how the SWA works for this model - i.e. is it symmetric or not:
# non-symmetric
token i attends to [i - n_swa, i]
# symmetric:
token i attends to [i - n_swa/2, i + n_swa/2]
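For reference, a minimal sketch (not taken from the PR) of what a symmetric-window check could look like, assuming plain int32_t positions (llama_pos is int32_t in llama.h):

#include <cstdint>
#include <cstdlib>

// Sketch only: decide whether key position p0 is masked for query position p1
// under a symmetric window of size n_swa, i.e. token i attends to [i - n_swa/2, i + n_swa/2].
static bool is_masked_swa_symmetric(int32_t n_swa, int32_t p0, int32_t p1) {
    if (n_swa == 0) {
        return false;                         // no window -> nothing is masked
    }
    return std::abs(p1 - p0) > n_swa / 2;     // outside the window -> masked
}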
@huydt-bti Hey, is the issue that you forgot to make this function so that swa is actually never applied?
@CISC No, I made it here:
Lines 349 to 352 in 454d7b7
if (hparams.use_alibi &&
    (hparams.n_swa == 0 || (pos_diff >= -half_n_swa && pos_diff <= half_n_swa))) {
    f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
I use build_attn_inp_no_cache()
#14014 (review)
Yes, but there's no kq_mask_swa, so is this even executed?
@CISC I have just implemented it. Please check again
You have to add the new arch here:
Lines 13195 to 13203 in 5a8ae30
switch (arch) {
    case LLM_ARCH_BERT:
    case LLM_ARCH_JINA_BERT_V2:
    case LLM_ARCH_NOMIC_BERT:
    case LLM_ARCH_NOMIC_BERT_MOE:
    case LLM_ARCH_WAVTOKENIZER_DEC:
        {
            res = nullptr;
        } break;
To avoid creating a memory module (a.k.a. KV cache) for these models.
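Concretely, that means extending the switch above with the new arch; a sketch, assuming the enum value introduced by this PR is named LLM_ARCH_MODERN_BERT:

switch (arch) {
    case LLM_ARCH_BERT:
    case LLM_ARCH_JINA_BERT_V2:
    case LLM_ARCH_NOMIC_BERT:
    case LLM_ARCH_NOMIC_BERT_MOE:
    case LLM_ARCH_MODERN_BERT:      // assumed name of the new enum value
    case LLM_ARCH_WAVTOKENIZER_DEC:
        {
            res = nullptr;          // no KV cache / memory module for these models
        } break;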
So, since the vocab is BPE you need to add it here: Line 1557 in 9f47fa5
Set the correct attribute on the [MASK] token, similarly to this: Lines 2097 to 2105 in 9f47fa5
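A hedged sketch of what that could look like; the helper names mirror the pattern used around those lines in llama-vocab.cpp, and the match string is an assumption, not the PR's actual diff:

// Sketch only: give "[MASK]" the LSTRIP attribute so a preceding space is
// consumed consistently when tokenizing masked inputs.
if (_contains_any(tokenizer_pre, {"modern-bert"})) {   // assumed match string
    _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
}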
Yep, I also noticed the same with
@CISC cc: @ggerganov I tried the embedding with various models, but the outputs barely change between attempts. Maybe the parameter loading or the inference graph has a problem somewhere. Can you check that part?
So, I just noticed at least part of the problem: Lines 1567 to 1571 in 3ac6753
We have cls, but not cls_b, so this has to be modified to handle that...
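For example, a sketch of making the bias optional in the pooling/classification head (assuming the head is built with the usual ggml calls; this is not the actual fix from the PR):

// Sketch only: apply the classification head, adding the bias only when present.
cur = ggml_mul_mat(ctx0, cls, inp);
if (cls_b) {
    cur = ggml_add(ctx0, cur, cls_b);
}
cur = ggml_tanh(ctx0, cur);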
@CISC After fixing that, the result is much better now :) but it is still lower than my expectations for ModernBert. Maybe there is a problem somewhere else...
Everything else LGTM, so pending finding the output issue.
src/llama-kv-cache-unified.cpp
Outdated
@@ -1328,6 +1328,12 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_SYMMETRIC:
            {
                if ( p1 - p0 <= (int32_t) n_swa / 2 || p0 - p1 >= (int32_t) n_swa / 2) {
@CISC I see part of the problem! I'm masking the tokens inside the window, when it should be the ones outside.
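In other words, the corrected check would mask only positions outside the window; a sketch, not the final diff:

case LLAMA_SWA_TYPE_SYMMETRIC:
    {
        // mask only when the distance exceeds half the window,
        // i.e. the position lies outside [i - n_swa/2, i + n_swa/2]
        if (p1 - p0 > (int32_t) n_swa / 2 || p0 - p1 > (int32_t) n_swa / 2) {
            return true;
        }
    } break;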
No, the function isn't used because it belongs to llama_kv_cache_unified.
src/llama-graph.cpp
Outdated
@@ -351,6 +351,69 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
            }
        }
    }

    // Handle symmetric SWA mask separately if it exists
    if (kq_mask_swa) {
No, this is unnecessary duplication; it should be handled like this once you add llm_graph_input/build_attn_inp_no_cache_iswa:
Lines 356 to 370 in d7da8dc
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask) {
        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
    }
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask) {
        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
    }

    if (self_kq_mask_swa) {
        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
    }
}
It is not interleaved SWA, so I still prefer using build_attn_inp_no_cache. I will try to refactor llm_graph_input_attn_no_cache::set_input, but the mechanism is different from llm_graph_input_attn_kv_unified and llm_graph_input_attn_kv_unified_iswa, since those use kv_state.
Sure, the point is just that you can handle it similarly.
I just pushed my code - but it's ugly to copy is_masked_swa and place it inside llm_graph_input_attn_no_cache::set_input. Do you have any suggestions?
I don't understand why you can't split the methods just like in unified: you can have an is_masked_swa implementation for no_cache and a much cleaner set_input.
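Roughly something like this (a sketch under the assumption that the PR's llm_graph_input_attn_no_cache gains a kq_mask_swa member; the method name is illustrative, not an existing API):

// Sketch only: a small predicate for the symmetric window...
bool llm_graph_input_attn_no_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
    if (hparams.n_swa == 0) {
        return false;
    }
    // symmetric window: token i attends to [i - n_swa/2, i + n_swa/2]
    return std::abs(p1 - p0) > (llama_pos) hparams.n_swa / 2;
}
// ...so that set_input only loops over the batch and fills kq_mask / kq_mask_swa,
// keeping it as thin as the unified / unified_iswa variants.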
Personally I don't think that would be cleaner, since the new build_attn_inp_no_cache_swa will be almost the same as the current build_attn_inp_no_cache, and we would end up with an additional build_attn_inp_no_cache variant. But I will try that.
That doesn't matter; the build_attn_inp_* functions are tiny. The important part is that you can reuse the same set_input_kq_mask.
I don't know whether my implementation is correct or not.