diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2232a7d82349e..22ff75dc8850d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -812,6 +812,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
             # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
             res = "minerva-7b"
+        if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
+            # ref: https://huggingface.co/answerdotai/ModernBERT-base
+            res = "modern-bert"
 
         if res is None:
             logger.warning("\n")
@@ -3949,6 +3952,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
+class ModernBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.MODERN_BERT
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
+        self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # These layers act as MLM head, so we don't need them
+        if name.startswith("decoder."):
+            return []
+
+        if name.startswith("model."):
+            name = name[6:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
 class RobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 2f733f0973686..53294fa5f82bc 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9b2143c7c2eaa..2c35ce4dd45c7 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -147,6 +147,7 @@ class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
         DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
         FREQ_BASE = "{arch}.rope.freq_base"
+        FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
         SCALING_TYPE = "{arch}.rope.scaling.type"
         SCALING_FACTOR = "{arch}.rope.scaling.factor"
         SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
@@ -289,6 +290,7 @@ class MODEL_ARCH(IntEnum):
     STARCODER = auto()
     REFACT = auto()
     BERT = auto()
+    MODERN_BERT = auto()
     NOMIC_BERT = auto()
     NOMIC_BERT_MOE = auto()
     JINA_BERT_V2 = auto()
@@ -479,6 +481,7 @@ class MODEL_TENSOR(IntEnum):
     ENC_FFN_UP = auto()
     ENC_OUTPUT_NORM = auto()
     CLS = auto() # classifier
+    CLS_NORM = auto() # classifier normalization
     CLS_OUT = auto() # classifier output projection
     CONV1D = auto()
     CONVNEXT_DW = auto()
@@ -571,6 +574,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.STARCODER: "starcoder",
     MODEL_ARCH.REFACT: "refact",
     MODEL_ARCH.BERT: "bert",
+    MODEL_ARCH.MODERN_BERT: "modern-bert",
     MODEL_ARCH.NOMIC_BERT: "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
@@ -761,6 +765,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
     MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
     MODEL_TENSOR.CLS: "cls",
+    MODEL_TENSOR.CLS_NORM: "cls.norm",
     MODEL_TENSOR.CLS_OUT: "cls.output",
     MODEL_TENSOR.CONV1D: "conv1d",
     MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
@@ -1051,6 +1056,20 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.CLS,
         MODEL_TENSOR.CLS_OUT,
     ],
+    MODEL_ARCH.MODERN_BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+        MODEL_TENSOR.CLS,
+        MODEL_TENSOR.CLS_NORM,
+        MODEL_TENSOR.CLS_OUT,
+    ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 54ca0c33fd336..429f408699fbf 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -813,6 +813,9 @@ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
     def add_rope_freq_base(self, value: float) -> None:
         self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
 
+    def add_rope_freq_base_swa(self, value: float) -> None:
+        self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
+
     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
         self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 5e3f01754bf07..d9e78af4399eb 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -16,6 +16,7 @@ class TensorNameMap:
             "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
+            "embeddings.tok_embeddings", # modern-bert
             "language_model.embedding.word_embeddings", # persimmon
             "wte", # gpt2
             "transformer.embd.wte", # phi2
@@ -42,6 +43,7 @@ class TensorNameMap:
         MODEL_TENSOR.TOKEN_EMBD_NORM: (
             "word_embeddings_layernorm", # bloom
             "embeddings.LayerNorm", # bert
+            "embeddings.norm", # modern-bert
             "emb_ln", # nomic-bert
             "transformer.norm", # openelm
             "rwkv.blocks.0.pre_ln", # rwkv
@@ -134,6 +136,7 @@ class TensorNameMap:
             "rwkv.blocks.{bid}.ln1", # rwkv6
             "model.layers.{bid}.ln1", # rwkv7
             "model.layers.{bid}.input_layernorm", # llama4
+            "layers.{bid}.attn_norm", # modern-bert
         ),
 
         # Attention norm 2
@@ -161,6 +164,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.qkv_proj", # phi3
             "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
             "transformer.layers.{bid}.attn.qkv_proj", # openelm
+            "layers.{bid}.attn.Wqkv", # modern-bert
         ),
 
         # Attention query
@@ -236,6 +240,7 @@ class TensorNameMap:
             "transformer.layers.{bid}.attn.out_proj", # openelm
             "transformer.h.{bid}.attn.attention.out_proj", # exaone
             "model.layers.{bid}.self_attn.o_proj", # llama4
+            "layers.{bid}.attn.Wo", # modern-bert
         ),
 
         # Attention output norm
@@ -245,6 +250,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.norm1", # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
+            "layers.{bid}.mlp_norm" # modern-bert
         ),
 
         MODEL_TENSOR.ATTN_POST_NORM: (
@@ -340,6 +346,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
             "transformer.h.{bid}.mlp.c_fc_1", # exaone
             "model.layers.{bid}.feed_forward.up_proj", # llama4
+            "layers.{bid}.mlp.Wi" # modern-bert
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -422,6 +429,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
             "model.layers.h.{bid}.mlp.c_proj", # exaone
             "model.layers.{bid}.feed_forward.down_proj", # llama4
+            "layers.{bid}.mlp.Wo" # modern-bert
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -832,12 +840,18 @@ class TensorNameMap:
         # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
             "encoder.final_layer_norm", # t5
+            "final_norm", # modern-bert
         ),
 
         MODEL_TENSOR.CLS: (
            "classifier", # jina
            "classifier.dense", # roberta
            "pre_classifier", # distillbert
+           "head.dense", # modern-bert
+        ),
+
+        MODEL_TENSOR.CLS_NORM: (
+            "head.norm", # modern-bert
        ),
 
        MODEL_TENSOR.CLS_OUT: (
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index a3e7c861ca02f..6f4cb68c66d2e 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER, "starcoder" },
     { LLM_ARCH_REFACT, "refact" },
     { LLM_ARCH_BERT, "bert" },
+    { LLM_ARCH_MODERN_BERT, "modern-bert" },
     { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
     { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
     { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
@@ -150,6 +151,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
+    { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
     { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
     { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
     { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
@@ -481,6 +483,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS_OUT, "cls.output" },
         },
     },
+    {
+        LLM_ARCH_MODERN_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+            { LLM_TENSOR_CLS, "cls" },
+            { LLM_TENSOR_CLS_NORM, "cls.norm" },
+            { LLM_TENSOR_CLS_OUT, "cls.output" },
+        },
+    },
     {
         LLM_ARCH_NOMIC_BERT,
         {
@@ -1619,6 +1638,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 168fdcb401cfd..41d340cc1bea4 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -22,6 +22,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER,
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
+    LLM_ARCH_MODERN_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_JINA_BERT_V2,
@@ -154,6 +155,7 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
     LLM_KV_ROPE_FREQ_BASE,
+    LLM_KV_ROPE_FREQ_BASE_SWA,
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
@@ -348,6 +350,7 @@ enum llm_tensor {
     LLM_TENSOR_ENC_FFN_UP,
     LLM_TENSOR_ENC_OUTPUT_NORM,
     LLM_TENSOR_CLS,
+    LLM_TENSOR_CLS_NORM,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_CONV1D,
     LLM_TENSOR_CONVNEXT_DW,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 3a113d1bcfb2a..a1f5b3043f79e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2002,6 +2002,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
     llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
     llama_set_param(model->cls, param_filter, param_filter_ud);
     llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_norm, param_filter, param_filter_ud);
     llama_set_param(model->cls_out, param_filter, param_filter_ud);
     llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
 
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 337fb5cb0df36..2cce90ef427f3 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -265,92 +265,131 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv = ubatch->n_tokens;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
+    // Helper function for SWA masking logic - mirrors llama_kv_cache_unified::is_masked_swa
+    auto is_masked_swa = [this](llama_pos p0, llama_pos p1) -> bool {
+        assert(p0 >= 0 && p1 >= 0);
 
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
+        switch (hparams.swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) hparams.n_swa) {
+                        return true;
                     }
-                }
-            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            const int64_t n_stride = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / hparams.n_swa) * hparams.n_swa;
+
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) hparams.n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
+
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
+                    }
+                } break;
+        }
+
+        return false;
+    };
+
+    // Helper function for setting attention mask
+    auto set_mask = [this, ubatch, &is_masked_swa](ggml_tensor * mask, bool apply_swa) {
+        if (!mask) {
+            return;
+        }
+
+        const int64_t n_tokens = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs = ubatch->n_seqs;
+        const int64_t n_kv = ubatch->n_tokens;
+        const int64_t n_stride = cparams.causal_attn ? n_kv : n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
+        float * data = (float *) mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int s1 = 0; s1 < n_seqs; ++s1) {
+                const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const int32_t tj = s1*n_seq_tokens + j;
+                    const llama_pos pos_j = ubatch->pos[tj];
+
+                    for (int s0 = 0; s0 < n_seqs; ++s0) {
+                        for (int i = 0; i < n_seq_tokens; ++i) {
+                            const int32_t ti = s0*n_seq_tokens + i;
+                            const llama_pos pos_i = ubatch->pos[ti];
+                            float f = -INFINITY;
+
+                            // Check sequence match
+                            bool sequence_match = false;
+                            for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                if (ubatch->seq_id[s0][s] == seq_id) {
+                                    sequence_match = true;
+                                    break;
+                                }
+                            }
+
+                            if (sequence_match) {
+                                bool masked = false;
+
+                                // Apply causal attention if enabled
+                                if (cparams.causal_attn && pos_i > pos_j) {
+                                    masked = true;
+                                }
+
+                                // Apply SWA masking if needed
+                                if (!masked && apply_swa) {
+                                    masked = masked || is_masked_swa(pos_i, pos_j);
                                 }
 
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                                if (!masked) {
+                                    if (hparams.use_alibi) {
+                                        f = -std::abs(pos_i - pos_j);
+                                    } else {
+                                        f = 0.0f;
+                                    }
+                                }
                             }
-                        }
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                            const int idx = h*(n_tokens*n_tokens) + tj*n_stride + ti;
+                            data[idx] = f;
                         }
                     }
+
+                    // Pad the rest of the row with -INFINITY
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        const int idx = h*(n_tokens*n_tokens) + tj*n_stride + i;
+                        data[idx] = -INFINITY;
+                    }
                 }
             }
         }
-    }
+
+        // Pad any remaining entries with -INFINITY
+        for (int tj = n_tokens; tj < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++tj) {
+            for (int ti = 0; ti < n_stride; ++ti) {
+                const int idx = 0*(n_tokens*n_tokens) + tj*n_stride + ti;
+                data[idx] = -INFINITY;
+            }
+        }
+    };
+
+    // Set regular attention mask
+    set_mask(kq_mask, false);
+
+    // Set SWA attention mask if available
+    set_mask(kq_mask_swa, true);
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
@@ -1174,6 +1213,15 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
+    // Create SWA mask for symmetric sliding window attention if SWA is enabled
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE && hparams.n_swa > 0) {
+        inp->kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->kq_mask_swa);
+
+        inp->kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask_swa, GGML_TYPE_F16) : inp->kq_mask_swa;
+    }
+
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
@@ -1197,7 +1245,8 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    // Select appropriate mask based on SWA type
+    const auto & kq_mask = hparams.is_swa(il) && inp->kq_mask_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
@@ -1521,7 +1570,8 @@ void llm_graph_context::build_pooling(
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
-        ggml_tensor * cls_out_b) const {
+        ggml_tensor * cls_out_b,
+        ggml_tensor * cls_norm) const {
     if (!cparams.embeddings) {
         return;
     }
@@ -1572,6 +1622,11 @@ void llm_graph_context::build_pooling(
                 }
                 cur = ggml_tanh(ctx0, cur);
 
+                if (cls_norm) {
+                    // normalization head
+                    cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
+                }
+
                 // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
                 // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
                 if (cls_out) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 87813119b1a3c..e4ff8d1512072 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,14 +219,18 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
         hparams(hparams),
         cparams(cparams) {
     }
+
     ~llm_graph_input_attn_no_cache() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return kq_mask_swa_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+    ggml_tensor * kq_mask         = nullptr; // F32 [n_tokens, n_batch]
+    ggml_tensor * kq_mask_cnv     = nullptr; //     [n_tokens, n_batch]
+    ggml_tensor * kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch] - for SWA
+    ggml_tensor * kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch] - for SWA
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -621,7 +625,8 @@ struct llm_graph_context {
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
-            ggml_tensor * cls_out_b) const;
+            ggml_tensor * cls_out_b,
+            ggml_tensor * cls_norm) const;
 };
 
 // TODO: better name
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 1499eb08a5dd9..c411e483d74a7 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -8,6 +8,12 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
     }
 }
 
+void llama_hparams::set_swa_pattern(uint32_t n_pattern, uint32_t remainder) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern != remainder);
+    }
+}
+
 bool llama_hparams::is_swa_any() const {
     for (uint32_t il = 0; il < n_layer; ++il) {
         if (swa_layers[il]) {
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index b2bcb8b01a18b..6ae6724fe712d 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -18,6 +18,7 @@ enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE      = 0,
     LLAMA_SWA_TYPE_STANDARD  = 1,
     LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -162,6 +163,18 @@ struct llama_hparams {
     // etc ...
     void set_swa_pattern(uint32_t n_pattern);
 
+    // Overload that allows specifying which position in the pattern is dense
+    // The remainder parameter specifies which position in the pattern is dense
+    // example: n_pattern = 3, remainder = 2
+    //   il == 0: swa   (0 % 3 = 0, which is not equal to 2)
+    //   il == 1: swa   (1 % 3 = 1, which is not equal to 2)
+    //   il == 2: dense (2 % 3 = 2, which equals 2)
+    //   il == 3: swa   (3 % 3 = 0, which is not equal to 2)
+    //   il == 4: swa   (4 % 3 = 1, which is not equal to 2)
+    //   il == 5: dense (5 % 3 = 2, which equals 2)
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern, uint32_t remainder);
+
     // return true if one of the layers is SWA
     bool is_swa_any() const;
 
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index b17936abdb4c6..a0a2f78a5791d 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1328,6 +1328,16 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                     return true;
                 }
             } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff = p1 - p0;
+
+                // Mask if outside the symmetric window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
     }
 
     return false;
diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp
index a70b9892347cb..69de2eb961b4e 100644
--- a/src/llama-model-saver.cpp
+++ b/src/llama-model-saver.cpp
@@ -265,6 +265,7 @@ void llama_model_saver::add_tensors_from_model() {
     add_tensor(model.output_norm_enc);
     add_tensor(model.cls);
     add_tensor(model.cls_b);
+    add_tensor(model.cls_norm);
     add_tensor(model.cls_out);
     add_tensor(model.cls_out_b);
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index dcc8b0be72563..70138ce956740 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -35,12 +35,14 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_80M: return "80M";
         case LLM_TYPE_109M: return "109M";
         case LLM_TYPE_137M: return "137M";
+        case LLM_TYPE_149M: return "149M";
         case LLM_TYPE_160M: return "160M";
         case LLM_TYPE_190M: return "190M";
         case LLM_TYPE_220M: return "220M";
         case LLM_TYPE_250M: return "250M";
         case LLM_TYPE_270M: return "270M";
         case LLM_TYPE_335M: return "335M";
+        case LLM_TYPE_395M: return "395M";
         case LLM_TYPE_410M: return "410M";
         case LLM_TYPE_450M: return "450M";
         case LLM_TYPE_475M: return "475M";
@@ -509,6 +511,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
     hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
     hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
 
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
@@ -720,6 +723,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MODERN_BERT:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+                hparams.set_swa_pattern(3, 0);
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 22:
+                        type = LLM_TYPE_149M; break; // modern-bert-base
+                    case 28:
+                        type = LLM_TYPE_395M; break; // modern-bert-large
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_JINA_BERT_V2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2212,6 +2233,33 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_MODERN_BERT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2*n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                    cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                    cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                    cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -6182,6 +6230,122 @@ struct llm_build_bert : public llm_graph_context {
     }
 };
 
+struct llm_build_modern_bert : public llm_graph_context {
+    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos()
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "inp_embd", -1);
+
+        // embed layer norm
+        inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            float rope_theta = il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa;
+
+            // attention layer norm
+            if (model.layers[il].attn_norm) {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            // attention layer norm
+            cur = build_norm(cur, model.layers[il].attn_out_norm, nullptr, LLM_NORM, il);
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,
+                    NULL, NULL, NULL, NULL, NULL,
+                    model.layers[il].ffn_down,
+                    NULL, NULL, NULL,
+                    LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
+
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_bloom : public llm_graph_context {
     llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13596,6 +13760,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_MODERN_BERT:
             {
                 res = nullptr;
             } break;
@@ -13703,6 +13868,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bert>(*this, params, gf);
            } break;
+        case LLM_ARCH_MODERN_BERT:
+            {
+                llm = std::make_unique<llm_build_modern_bert>(*this, params, gf);
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 llm = std::make_unique<llm_build_bloom>(*this, params, gf);
             } break;
@@ -13938,7 +14107,7 @@ llm_graph_result_ptr llama_model::build_graph(
     }
 
     // add on pooling layer
-    llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
+    llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b, cls_norm);
 
     return std::move(llm->res);
 }
@@ -14092,6 +14261,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
+        case LLM_ARCH_MODERN_BERT:
         case LLM_ARCH_STABLELM:
         case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN:
diff --git a/src/llama-model.h b/src/llama-model.h
index 06e6c687943cc..ec4ee1e177c12 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -28,12 +28,14 @@ enum llm_type {
     LLM_TYPE_80M,
     LLM_TYPE_109M,
     LLM_TYPE_137M,
+    LLM_TYPE_149M,
     LLM_TYPE_160M,
     LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
     LLM_TYPE_270M,
     LLM_TYPE_335M,
+    LLM_TYPE_395M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
@@ -344,6 +346,7 @@ struct llama_model {
     struct ggml_tensor * output = nullptr;
     struct ggml_tensor * output_b = nullptr;
     struct ggml_tensor * output_norm_enc = nullptr;
+    struct ggml_tensor * cls_norm = nullptr;
 
     // classifier
     struct ggml_tensor * cls = nullptr;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index dd2251ef3cbef..b7f9482ea6e7e 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1554,7 +1554,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-v2-es" ||
                 tokenizer_pre == "jina-v2-de" ||
                 tokenizer_pre == "jina-v2-code" ||
-                tokenizer_pre == "roberta-bpe") {
+                tokenizer_pre == "roberta-bpe" ||
+                tokenizer_pre == "modern-bert") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
             } else if (
                 tokenizer_pre == "refact") {
@@ -2104,6 +2105,12 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            } else {
                _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
            }
+        } else if (_contains_any(general_arch, {"modern-bert"})) {
+            if (token_to_id.count("[MASK]") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token not found in vocab!\n", __func__);
+            } else {
+                _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : cache_special_tokens) {
                 _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);