From 79ebef84f9c9823009da1077a22e26cb175a26e9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 6 Apr 2025 18:47:56 +0200 Subject: [PATCH 01/20] llama4 conversion --- convert_hf_to_gguf.py | 65 +++++++++++++++-- convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 26 +++++++ models/ggml-vocab-llama4.gguf.inp | 112 ++++++++++++++++++++++++++++++ models/ggml-vocab-llama4.gguf.out | 46 ++++++++++++ 5 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 models/ggml-vocab-llama4.gguf.inp create mode 100644 models/ggml-vocab-llama4.gguf.out diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cfe94deaf76ef..0cab6cc937789 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -714,6 +714,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224": # ref: https://huggingface.co/inclusionAI/Ling-lite res = "bailingmoe" + if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": + # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct + res = "llama4" if res is None: logger.warning("\n") @@ -1608,6 +1611,7 @@ def prepare_tensors(self): @Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + undo_permute = True def set_vocab(self): try: @@ -1672,10 +1676,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - if name.endswith(("q_proj.weight", "q_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight", "k_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if self.undo_permute: + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately if name.find("block_sparse_moe.experts") != -1: @@ -1752,6 +1757,58 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("Llama4ForConditionalGeneration") +class Llama4Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA4 + has_vision: bool = False + undo_permute = False + + # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config" + # same with llama, but we need to merge the text_config into the root level of hparams + def __init__(self, *args, **kwargs): + hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) + if "text_config" in hparams: + hparams = {**hparams, **hparams["text_config"]} + kwargs["hparams"] = hparams + super().__init__(*args, **kwargs) + if "vision_config" in hparams: + logger.info("Has vision encoder, but it will be ignored") + self.has_vision = True + # hacky renaming + self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] + self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + # TODO @ngxson : this is for testing, will be cleaned up later + self.gguf_writer.add_uint32("llama4.interleave_moe_layer_step", self.hparams["interleave_moe_layer_step"]) + self.gguf_writer.add_uint32("llama4.no_rope_layer_interval", 4) # every 4th layer + 
self.gguf_writer.add_uint32("llama4.expert_feed_forward_length", self.hparams["intermediate_size_moe"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("language_model.", "") + name = name.replace("feed_forward.", "mlp.") # a bit hacky for now + name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now + + if "gate_up_proj" in name: + name_up = name.replace("gate_up_proj", "up_proj.weight") + name_gate = name.replace("gate_up_proj", "gate_proj.weight") + dim_half = data_torch.shape[-1] // 2 + gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2) + return [ + (self.map_tensor_name(name_gate), gate_proj_weight), + (self.map_tensor_name(name_up), up_proj_weight) + ] + + if name.endswith("down_proj"): + name += ".weight" + data_torch = data_torch.transpose(-1, -2) + + if "multi_modal_projector" in name or "vision_model" in name: + return [] + return super().modify_tensors(data_torch, name, bid) + + @Model.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 1b86f4c90acf6..ce6104da4a0e9 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -113,6 +113,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, + {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3a52cfd1e39ac..d4f4e117993e2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -116,6 +116,7 @@ class LLM: RESIDUAL_SCALE = "{arch}.residual_scale" EMBEDDING_SCALE = "{arch}.embedding_scale" TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" + INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -227,6 +228,7 @@ class GGUFType: class MODEL_ARCH(IntEnum): LLAMA = auto() + LLAMA4 = auto() DECI = auto() FALCON = auto() BAICHUAN = auto() @@ -431,6 +433,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", MODEL_ARCH.DECI: "deci", MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.BAICHUAN: "baichuan", @@ -654,6 +657,29 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.LLAMA4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.DECI: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/models/ggml-vocab-llama4.gguf.inp b/models/ggml-vocab-llama4.gguf.inp new file mode 100644 index 
0000000000000..9baf7d77ae6b5 --- /dev/null +++ b/models/ggml-vocab-llama4.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama4.gguf.out b/models/ggml-vocab-llama4.gguf.out new file mode 100644 index 0000000000000..7ca46ce597b85 --- /dev/null +++ b/models/ggml-vocab-llama4.gguf.out @@ -0,0 +1,46 @@ + 1190 220 32 220 18215 7112 + 50 16800 258 + + 220 + 256 + 277 + 197 + 198 + 368 + 2946 + 3271 + 19873 3817 + 39715 3817 + 19873 7353 + 39715 7353 + 39715 7353 13 + 19873 24 3817 13 + 39715 24 3817 13 + 544 373 9522 112 247 26 36315 + 99 39923 220 35 9607 21498 21470 3679 9433 + 1595 7653 633 79829 34051 1636 + 8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234 + 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21 + 19873 + 39715 + 220 39715 + 256 39715 + 277 39715 + 277 39715 198 277 39715 + 330 + 198 319 + 19 7359 + 19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 + 17931 4959 + 31 + 1922 + 12325 + 12325 31 + 12325 1922 + 12325 12325 + 12325 12325 31 + 12325 12325 1922 + 12325 12325 12325 + 47 19811 12077 + 3260 3579 + 198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56 From b19dbd0149ec0fbe38aff12325304a1250f0b135 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 00:02:12 +0200 Subject: [PATCH 02/20] initial support, no chat template --- include/llama.h | 1 + src/llama-arch.cpp | 30 ++++++++++++++++ src/llama-arch.h | 1 + src/llama-graph.cpp | 8 +++++ src/llama-hparams.h | 5 +++ src/llama-model.cpp | 85 +++++++++++++++++++++++++++++++++++++++------ src/llama-vocab.cpp | 3 +- 7 files changed, 121 insertions(+), 12 deletions(-) diff --git a/include/llama.h b/include/llama.h index fca2b034ba270..5fd99f4c8b99e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -110,6 +110,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, }; enum llama_rope_type { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 047782e7d0fc8..fbbba1b7773f0 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -6,6 +6,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, @@ -233,6 +234,35 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, 
"blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_DECI, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 297cfa4dae571..9345cb66699af 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -10,6 +10,7 @@ enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cec203df49268..9555201f8813e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -841,6 +841,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(selection_probs, "ffn_moe_probs_biased", il); } + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + // select experts ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); @@ -914,6 +920,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( moe_out = ggml_cont(ctx0, moe_out); } + cb(moe_out, "ffn_moe_out", il); + return moe_out; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bb17ba86dc2fb..361756907e18f 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -112,6 +112,11 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + // TODO @ngxson : variable names taken from python code, we can rename it later + uint32_t interleave_moe_layer_step = 2; // TODO read from gguf + uint32_t no_rope_layer_interval = 4; // TODO read from gguf + uint32_t attn_temperature_tuning = 4; // TODO read from gguf + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ca6e3ab2caeb1..fa447d343fc52 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -524,8 +524,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); if (hparams.n_expert == 8) { switch (hparams.n_layer) { @@ -1631,6 +1633,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const auto tn = LLM_TN(arch); switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_REFACT: case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: @@ -1648,6 +1651,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.interleave_moe_layer_step == 0; + auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -1673,7 +1678,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - if (n_expert == 0) { + int n_ff_exp = hparams.n_ff_exp; + if (n_expert == 0 || !is_moe_layer) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); @@ -1684,9 +1690,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch (only used by llama 4 for now) + if (arch == LLM_ARCH_LLAMA4) { + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } } } } break; @@ -4209,6 +4223,10 @@ struct llm_build_llama : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + bool use_rope = arch == LLM_ARCH_LLAMA4 + ? 
(il + 1) % hparams.no_rope_layer_interval != 0 + : true; + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -4246,25 +4264,39 @@ struct llm_build_llama : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( + if (use_rope) { + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + // TODO: support temperature tuning (attn_temperature_tuning) + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + if (arch == LLM_ARCH_LLAMA4 && use_rope) { + // Llama4TextL2Norm + // TODO @ngxson : the 128E model does not use qk_norm + Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); + Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); } if (il == n_layer - 1) { @@ -4282,7 +4314,7 @@ struct llm_build_llama : public llm_graph_context { ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // feed-forward network + // feed-forward network (non-MoE) if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_norm(ffn_inp, @@ -4297,6 +4329,35 @@ struct llm_build_llama : public llm_graph_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); + + } else if (arch == LLM_ARCH_LLAMA4) { + // llama4 MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + + // Shared experts + cur = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_moe_shexp", il); + } else { // MoE branch cur = build_norm(ffn_inp, @@ -12091,6 +12152,7 @@ llm_graph_result_ptr llama_model::build_graph( switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: @@ -12440,6 +12502,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5d5cafbea1f1b..0feabd95aaf2b 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1616,7 +1616,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre 
== "llama4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; } else if ( From f6d8e753263cacc60aaa0aa902f947ec6ead5416 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 00:33:16 +0200 Subject: [PATCH 03/20] clean up a bit --- src/llama-hparams.h | 1 + src/llama-model.cpp | 26 ++++++++++++++++++++------ src/llama-model.h | 1 + 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 361756907e18f..1844b4373f9e9 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { uint32_t interleave_moe_layer_step = 2; // TODO read from gguf uint32_t no_rope_layer_interval = 4; // TODO read from gguf uint32_t attn_temperature_tuning = 4; // TODO read from gguf + uint32_t floor_scale = 8192; // TODO read from gguf // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fa447d343fc52..7f9b63d966c97 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -90,6 +90,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_290B: return "290B"; + case LLM_TYPE_17B_16E: return "17Bx16E"; default: return "?B"; } } @@ -524,10 +525,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA4: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); if (hparams.n_expert == 8) { switch (hparams.n_layer) { @@ -552,6 +551,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + hparams.f_attention_scale = 0.1; + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_17B_16E; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4266,10 +4276,10 @@ struct llm_build_llama : public llm_graph_context { if (use_rope) { Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, @@ -4278,6 +4288,10 @@ struct llm_build_llama : public llm_graph_context { ); } else { // TODO: support temperature tuning (attn_temperature_tuning) + // Problem: we are missing 2 things: + // - ggml_cast from I32 to F32 + // - ggml_floor + // Ref implementation: https://github.com/ml-explore/mlx-lm/blob/9df43c9863c28065fecf87c9be2c5fd7e6f3864c/mlx_lm/models/llama4.py#L122-L130 } cb(Qcur, "Qcur", il); diff --git a/src/llama-model.h b/src/llama-model.h index 91e6e8725acd2..46d799baadc62 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -86,6 +86,7 @@ enum llm_type { LLM_TYPE_57B_A14B, LLM_TYPE_27B, LLM_TYPE_290B, + LLM_TYPE_17B_16E, // llama4 Scout }; struct llama_layer_posnet { From 1fb1888a8b3f5e606953c87256de7282b46e89f2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 00:35:51 +0200 Subject: [PATCH 04/20] fix tokenizer 
conversion --- convert_hf_to_gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cab6cc937789..3b12a4b3f6c9c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1778,6 +1778,9 @@ def __init__(self, *args, **kwargs): self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] + def set_vocab(self): + self._set_vocab_gpt2() + def set_gguf_parameters(self): super().set_gguf_parameters() # TODO @ngxson : this is for testing, will be cleaned up later From 869d7d979c3358787d950bfc553906d1cbd5d56d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 08:37:41 +0200 Subject: [PATCH 05/20] correct hparams --- src/llama-hparams.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 1844b4373f9e9..a7be073a019dc 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -113,7 +113,7 @@ struct llama_hparams { bool attn_soft_cap = false; // TODO @ngxson : variable names taken from python code, we can rename it later - uint32_t interleave_moe_layer_step = 2; // TODO read from gguf + uint32_t interleave_moe_layer_step = 1; // TODO read from gguf uint32_t no_rope_layer_interval = 4; // TODO read from gguf uint32_t attn_temperature_tuning = 4; // TODO read from gguf uint32_t floor_scale = 8192; // TODO read from gguf From 6ceae82e32e200fcbdeecc9c596faccbe48ffdc7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 08:48:49 +0200 Subject: [PATCH 06/20] try this --- src/llama-model.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7f9b63d966c97..ae34f06568b43 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4364,13 +4364,15 @@ struct llm_build_llama : public llm_graph_context { il); // Shared experts - cur = build_ffn(cur, + ggml_tensor * shexp_out = build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_moe_shexp", il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, cur, shexp_out); } else { // MoE branch From 7cfc2373ab63fc77bc382f55bffe4bee3141db35 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 09:02:40 +0200 Subject: [PATCH 07/20] fix shexp --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ae34f06568b43..bf51d454eee88 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4364,7 +4364,7 @@ struct llm_build_llama : public llm_graph_context { il); // Shared experts - ggml_tensor * shexp_out = build_ffn(cur, + ggml_tensor * shexp_out = build_ffn(ffn_inp, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, From edbaaf4640296eb1d3ef701af008efb84b10f631 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 09:04:29 +0200 Subject: [PATCH 08/20] ffn_inp_normed --- src/llama-model.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bf51d454eee88..e95f11af67468 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4346,12 +4346,12 @@ struct llm_build_llama : public llm_graph_context { } else if (arch == LLM_ARCH_LLAMA4) { 
// llama4 MoE - cur = build_norm(ffn_inp, + ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, + ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -4364,7 +4364,7 @@ struct llm_build_llama : public llm_graph_context { il); // Shared experts - ggml_tensor * shexp_out = build_ffn(ffn_inp, + ggml_tensor * shexp_out = build_ffn(ffn_inp_normed, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -4372,7 +4372,7 @@ struct llm_build_llama : public llm_graph_context { LLM_FFN_SILU, LLM_FFN_PAR, il); cb(shexp_out, "ffn_moe_shexp", il); - cur = ggml_add(ctx0, cur, shexp_out); + cur = ggml_add(ctx0, moe_out, shexp_out); } else { // MoE branch From a518c11eba755ad84f15061e070a7657eed82ea8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 09:08:12 +0200 Subject: [PATCH 09/20] chat template --- src/llama-chat.cpp | 14 +++++++++++++- src/llama-chat.h | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index dd27a381423df..721faa4e8147e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -61,6 +61,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -174,6 +175,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_YANDEX; } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { return LLM_CHAT_TEMPLATE_BAILING; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -608,7 +611,16 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "ASSISTANT"; } - } else { + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } + } else { // template not supported return -1; } diff --git a/src/llama-chat.h b/src/llama-chat.h index 0e0bd772c2eac..34537ca21e46e 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -40,6 +40,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, LLM_CHAT_TEMPLATE_BAILING, + LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_UNKNOWN, }; From 46fe5cbf790b1fed10fa372b0fae36e9c26153bc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 11:47:44 +0200 Subject: [PATCH 10/20] clean up model conversion --- convert_hf_to_gguf.py | 11 +++-- gguf-py/gguf/gguf_writer.py | 3 ++ src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-hparams.h | 12 +++--- src/llama-model.cpp | 85 ++++++++++++++++++++++++++++--------- src/llama-model.h | 1 + 7 files changed, 82 insertions(+), 32 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3b12a4b3f6c9c..0620d00dfffbd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1774,7 +1774,7 @@ def __init__(self, *args, **kwargs): if "vision_config" in hparams: logger.info("Has 
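(Aside on the LLM_CHAT_TEMPLATE_LLAMA4 branch added in patch 09 above: it can be exercised outside llama.cpp. A rough Python equivalent of the C++ loop, with `strip()` approximating the C++ `trim()`.)

```python
def render_llama4(messages, add_assistant_prompt=True):
    # mirrors the LLM_CHAT_TEMPLATE_LLAMA4 branch: <|header_start|>role<|header_end|>\n\ncontent<|eot|>
    out = ""
    for m in messages:
        out += f"<|header_start|>{m['role']}<|header_end|>\n\n{m['content'].strip()}<|eot|>"
    if add_assistant_prompt:
        out += "<|header_start|>assistant<|header_end|>\n\n"
    return out

print(render_llama4([{"role": "user", "content": "Hello"}]))
# <|header_start|>user<|header_end|>
#
# Hello<|eot|><|header_start|>assistant<|header_end|>
#
```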
vision encoder, but it will be ignored") self.has_vision = True - # hacky renaming + # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] @@ -1783,16 +1783,15 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - # TODO @ngxson : this is for testing, will be cleaned up later - self.gguf_writer.add_uint32("llama4.interleave_moe_layer_step", self.hparams["interleave_moe_layer_step"]) - self.gguf_writer.add_uint32("llama4.no_rope_layer_interval", 4) # every 4th layer - self.gguf_writer.add_uint32("llama4.expert_feed_forward_length", self.hparams["intermediate_size_moe"]) + self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace("language_model.", "") name = name.replace("feed_forward.", "mlp.") # a bit hacky for now name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now + # split the gate_up into gate and up if "gate_up_proj" in name: name_up = name.replace("gate_up_proj", "up_proj.weight") name_gate = name.replace("gate_up_proj", "gate_proj.weight") @@ -1802,7 +1801,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): (self.map_tensor_name(name_gate), gate_proj_weight), (self.map_tensor_name(name_up), up_proj_weight) ] - + if name.endswith("down_proj"): name += ".weight" data_torch = data_torch.transpose(-1, -2) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af8b388dfaba5..485550aad6da4 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -746,6 +746,9 @@ def add_wkv_head_size(self, size: int) -> None: def add_token_shift_count(self, count: int) -> None: self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count) + def add_interleave_moe_layer_step(self, value: int) -> None: + self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fbbba1b7773f0..ac997b963c85a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -115,6 +115,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 9345cb66699af..42e4a3ef95e35 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -119,6 +119,7 @@ enum llm_kv { LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index a7be073a019dc..9f52b5ab5be72 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -112,11 +112,13 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; - // TODO @ngxson : variable 
names taken from python code, we can rename it later - uint32_t interleave_moe_layer_step = 1; // TODO read from gguf - uint32_t no_rope_layer_interval = 4; // TODO read from gguf - uint32_t attn_temperature_tuning = 4; // TODO read from gguf - uint32_t floor_scale = 8192; // TODO read from gguf + uint32_t n_moe_layer_step = 0; + bool use_kq_norm = true; + // values below seems to be fixed on llama4 + uint32_t n_no_rope_layer_step = 4; + uint32_t n_attn_temp_tuning = 4; + uint32_t n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e95f11af67468..50be05fc8d315 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -90,7 +90,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_290B: return "290B"; - case LLM_TYPE_17B_16E: return "17Bx16E"; + case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; + case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; default: return "?B"; } } @@ -555,11 +556,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - hparams.f_attention_scale = 0.1; + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_17B_16E; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_expert) { + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_17B_128E) { + hparams.use_kq_norm = false; } } break; case LLM_ARCH_DECI: @@ -1643,7 +1649,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const auto tn = LLM_TN(arch); switch (arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA4: case LLM_ARCH_REFACT: case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: @@ -1661,8 +1666,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = (i + 1) % hparams.interleave_moe_layer_step == 0; - auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -1688,8 +1691,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); } - int n_ff_exp = hparams.n_ff_exp; - if (n_expert == 0 || !is_moe_layer) { + if (n_expert == 0) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); @@ -1700,17 +1702,59 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } + } break; + case LLM_ARCH_LLAMA4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - // Shared expert branch (only used by llama 4 for now) - if (arch == LLM_ARCH_LLAMA4) { - const int64_t n_ff_shexp = n_ff_exp; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } } break; @@ -4234,7 +4278,7 @@ struct llm_build_llama : public llm_graph_context { ggml_tensor * inpSA = inpL; bool use_rope = arch == LLM_ARCH_LLAMA4 - ? (il + 1) % hparams.no_rope_layer_interval != 0 + ? 
(il + 1) % hparams.n_no_rope_layer_step != 0 : true; // norm @@ -4298,9 +4342,8 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - if (arch == LLM_ARCH_LLAMA4 && use_rope) { + if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { // Llama4TextL2Norm - // TODO @ngxson : the 128E model does not use qk_norm Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); cb(Qcur, "Qcur_normed", il); diff --git a/src/llama-model.h b/src/llama-model.h index 46d799baadc62..0f18dac16733b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -87,6 +87,7 @@ enum llm_type { LLM_TYPE_27B, LLM_TYPE_290B, LLM_TYPE_17B_16E, // llama4 Scout + LLM_TYPE_17B_128E, // llama4 Maverick }; struct llama_layer_posnet { From ab91ab2f1a5f696feb0aea80a60b8407aec8648b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 12:01:44 +0200 Subject: [PATCH 11/20] add_bos --- convert_hf_to_gguf.py | 1 + src/llama-model.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0620d00dfffbd..9549900206b48 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1780,6 +1780,7 @@ def __init__(self, *args, **kwargs): def set_vocab(self): self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) def set_gguf_parameters(self): super().set_gguf_parameters() diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 50be05fc8d315..a5811d315c100 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4416,6 +4416,7 @@ struct llm_build_llama : public llm_graph_context { cb(shexp_out, "ffn_moe_shexp", il); cur = ggml_add(ctx0, moe_out, shexp_out); + cb(shexp_out, "ffn_moe_out_merged", il); } else { // MoE branch From f9c788df845d800af8e8ed8b4822b38b9c74a7d8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 14:03:14 +0200 Subject: [PATCH 12/20] add scale_before_ffn --- src/llama-graph.cpp | 14 +++++++++++--- src/llama-model.cpp | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9555201f8813e..0fee517656065 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -812,8 +812,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il) const { - int64_t n_embd = cur->ne[0]; - int64_t n_tokens = cur->ne[1]; + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + const bool scale_before_ffn = arch == LLM_ARCH_LLAMA4; ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); @@ -873,6 +874,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + + if (scale_before_ffn) { + cur = ggml_mul(ctx0, cur, weights); + } + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); @@ -900,7 +906,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - experts = ggml_mul(ctx0, experts, weights); + if (!scale_before_ffn) { + experts = ggml_mul(ctx0, experts, weights); + } // aggregate experts ggml_tensor * moe_out = nullptr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5811d315c100..15ba6ea34f01d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4416,7 +4416,7 @@ 
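(Aside on the QK norm above: the `ggml_rms_norm(..., 1e-6)` applied to Qcur and Kcur on the RoPE layers is a plain, weightless RMS norm — what the patch comment calls Llama4TextL2Norm — and it is skipped for the 128E variant via `use_kq_norm`. A standalone sketch with a toy tensor and the eps value taken from the patch.)

```python
import torch

def qk_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), no learned scale, normalized over the head dimension
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

q = torch.randn(3, 4, 8)   # toy [n_tokens, n_head, head_dim]
print(qk_norm(q).shape)
```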
struct llm_build_llama : public llm_graph_context { cb(shexp_out, "ffn_moe_shexp", il); cur = ggml_add(ctx0, moe_out, shexp_out); - cb(shexp_out, "ffn_moe_out_merged", il); + cb(cur, "ffn_moe_out_merged", il); } else { // MoE branch From e4012e622c3a29bb76b045d723642e043e218e17 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 14:19:20 +0200 Subject: [PATCH 13/20] fix order --- src/llama-graph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0fee517656065..8b0eea89c463f 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -875,13 +875,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); - if (scale_before_ffn) { - cur = ggml_mul(ctx0, cur, weights); - } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); + if (scale_before_ffn) { + up = ggml_mul(ctx0, up, weights); + } + ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(gate, "ffn_moe_gate", il); From ee06e9b710a3da49b1c9fec34741ea0ee44d25d9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 17:08:13 +0200 Subject: [PATCH 14/20] weight_before_ffn --- src/llama-graph.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8b0eea89c463f..d6ffce452fb0b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -814,7 +814,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int il) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; - const bool scale_before_ffn = arch == LLM_ARCH_LLAMA4; + const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); @@ -875,13 +875,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + if (weight_before_ffn) { + ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); + repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] + cur = ggml_mul(ctx0, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); - if (scale_before_ffn) { - up = ggml_mul(ctx0, up, weights); - } - ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(gate, "ffn_moe_gate", il); @@ -906,7 +909,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - if (!scale_before_ffn) { + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); } From f8f1bd4d211d0d2e8220f515cbe5d14224535f2c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 18:32:28 +0200 Subject: [PATCH 15/20] llm_graph_input_attn_temp --- src/llama-graph.cpp | 16 ++++++++++++++++ src/llama-graph.h | 17 +++++++++++++++++ src/llama-hparams.h | 1 - src/llama-model.cpp | 18 ++++++++++++------ 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 
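(Aside on the MoE routing changes in patches 12-14 above: taken together they mean experts are picked by top-k on the raw router logits, the sigmoid is applied only to the selected scores, the routing weight is multiplied into the expert *input* rather than its output, and the shared expert output is added at the end. A slow, loop-based sketch with toy dimensions and random weights — not the batched ggml implementation.)

```python
import torch
import torch.nn.functional as F

# toy dimensions and random weights; names loosely follow the gguf tensor names
n_embd, n_ff_exp, n_expert, n_expert_used, n_tokens = 8, 16, 4, 1, 3

x      = torch.randn(n_tokens, n_embd)            # ffn_inp_normed
router = torch.randn(n_expert, n_embd)            # ffn_gate_inp
w_gate = torch.randn(n_expert, n_ff_exp, n_embd)  # ffn_gate_exps
w_up   = torch.randn(n_expert, n_ff_exp, n_embd)  # ffn_up_exps
w_down = torch.randn(n_expert, n_embd, n_ff_exp)  # ffn_down_exps
s_gate = torch.randn(n_ff_exp, n_embd)            # ffn_gate_shexp
s_up   = torch.randn(n_ff_exp, n_embd)            # ffn_up_shexp
s_down = torch.randn(n_embd, n_ff_exp)            # ffn_down_shexp

logits = x @ router.T                             # ffn_moe_logits [n_tokens, n_expert]
# llama4: experts are selected on the raw logits, sigmoid comes after top-k
sel_scores, sel_idx = logits.topk(n_expert_used, dim=-1)
weights = torch.sigmoid(sel_scores)               # [n_tokens, n_expert_used]

moe_out = torch.zeros_like(x)
for t in range(n_tokens):
    for j in range(n_expert_used):
        e = sel_idx[t, j]
        xe = x[t] * weights[t, j]                 # routing weight applied *before* the expert FFN
        moe_out[t] += w_down[e] @ (F.silu(w_gate[e] @ xe) * (w_up[e] @ xe))

# the shared expert runs on the same normed input; its output is added to the routed output
shexp = s_down @ (F.silu(s_gate @ x.T) * (s_up @ x.T))  # [n_embd, n_tokens]
out = moe_out + shexp.T                                 # ffn_moe_out_merged
print(out.shape)
```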
d6ffce452fb0b..2a74c1b77cadc 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -59,6 +59,22 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { + if (ubatch->pos && attn_scale) { + const int64_t n_tokens = ubatch->n_tokens; + + std::vector attn_scale_data(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const float pos = ubatch->pos[i]; + attn_scale_data[i] = std::log( + std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0 + ) * f_attn_temp_scale + 1.0; + } + + ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale)); + } +} + void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { const int64_t n_tokens = ubatch->n_tokens; diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..4a88bbfc48281 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -100,6 +100,23 @@ class llm_graph_input_pos : public llm_graph_input_i { const int64_t n_pos_per_token = 1; }; +// temperature tuning, used by llama4 +class llm_graph_input_attn_temp : public llm_graph_input_i { +public: + llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) + : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} + virtual ~llm_graph_input_attn_temp() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * attn_scale = nullptr; // F32 [n_batch] + + const int64_t n_pos_per_token = 1; + + const uint32_t n_attn_temp_floor_scale; + const float f_attn_temp_scale; +}; + class llm_graph_input_pos_bucket : public llm_graph_input_i { public: llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {} diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9f52b5ab5be72..f1f93deb6d1e7 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -116,7 +116,6 @@ struct llama_hparams { bool use_kq_norm = true; // values below seems to be fixed on llama4 uint32_t n_no_rope_layer_step = 4; - uint32_t n_attn_temp_tuning = 4; uint32_t n_attn_temp_floor_scale = 8192; float f_attn_temp_scale = 0.1; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 15ba6ea34f01d..37ce51a023155 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4271,6 +4271,16 @@ struct llm_build_llama : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (arch == LLM_ARCH_LLAMA4) { + auto inp = std::make_unique(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token()); + ggml_set_input(inp_attn_scale); + inp->attn_scale = inp_attn_scale; + res->add_input(std::move(inp)); + } + auto * inp_attn = build_attn_inp_kv_unified(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
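(Aside on the temperature tuning added in patch 15: the per-token scale computed in `llm_graph_input_attn_temp::set_input` — and multiplied into Qcur on the layers that skip RoPE, i.e. every 4th layer, see the `ggml_mul` a bit further down — is a simple function of the absolute position. A standalone sketch using the constants hard-coded in the patch.)

```python
import math

FLOOR_SCALE     = 8192   # hparams.n_attn_temp_floor_scale
ATTN_TEMP_SCALE = 0.1    # hparams.f_attn_temp_scale

def attn_temp(pos: int) -> float:
    # log(floor((pos + 1) / floor_scale) + 1) * scale + 1, as in set_input()
    return math.log(math.floor((pos + 1) / FLOOR_SCALE) + 1) * ATTN_TEMP_SCALE + 1.0

for p in (0, 8190, 8191, 65535):
    print(p, round(attn_temp(p), 4))
# the scale stays at 1.0 until the position crosses the floor_scale boundary,
# then grows logarithmically with position
```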
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -4330,12 +4340,8 @@ struct llm_build_llama : public llm_graph_context { n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - } else { - // TODO: support temperature tuning (attn_temperature_tuning) - // Problem: we are missing 2 things: - // - ggml_cast from I32 to F32 - // - ggml_floor - // Ref implementation: https://github.com/ml-explore/mlx-lm/blob/9df43c9863c28065fecf87c9be2c5fd7e6f3864c/mlx_lm/models/llama4.py#L122-L130 + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); } cb(Qcur, "Qcur", il); From e6a2809c2d42042cb5e64052117be1e36af53b83 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 19:06:01 +0200 Subject: [PATCH 16/20] add chunk attn mask --- src/llama-graph.cpp | 12 ++++++++++-- src/llama-hparams.h | 1 + src/llama-model.cpp | 5 +++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 2a74c1b77cadc..5a948f0849426 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -474,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { } // may need to cut off old tokens for sliding window + // TODO @ngxson : the check for n_attn_chunk is temporary, need to optimize it if (data_swa) { - if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } } data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index f1f93deb6d1e7..4e0b57190a3a7 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -114,6 +114,7 @@ struct llama_hparams { uint32_t n_moe_layer_step = 0; bool use_kq_norm = true; + uint32_t n_attn_chunk = 0; // values below seems to be fixed on llama4 uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 8192; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 37ce51a023155..c46740c8c89c1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -557,6 +557,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + // hack: we use SWA to store the chunked attn mask + // luckily, the n_swa_pattern is the same as chunked layer pattern: 3 chunked - 1 full + hparams.n_swa_pattern = 4; + hparams.n_attn_chunk = 8192; // should this be a gguf kv? 
currently it's the same for Scout and Maverick + hparams.n_swa = 1; // unused, added to trigger the SWA switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; From af1968c3547c7970d68590eb582d94f3a1b84ac9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 21:26:25 +0200 Subject: [PATCH 17/20] build_inp_attn_scale() --- src/llama-graph.cpp | 13 +++++++++++++ src/llama-graph.h | 1 + src/llama-model.cpp | 6 +----- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5a948f0849426..be6b1d82889ed 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1024,6 +1024,19 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_attn_scale() const { + auto inp = std::make_unique(n_pos_per_token()); + + auto & cur = inp->attn_scale; + + cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token()); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + ggml_tensor * llm_graph_context::build_inp_out_ids() const { auto inp = std::make_unique(hparams, cparams, n_outputs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 4a88bbfc48281..5b6618f9e55f1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -487,6 +487,7 @@ struct llm_graph_context { ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; ggml_tensor * build_inp_pos() const; + ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c46740c8c89c1..2d02f19f16097 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4279,11 +4279,7 @@ struct llm_build_llama : public llm_graph_context { // temperature tuning ggml_tensor * inp_attn_scale = nullptr; if (arch == LLM_ARCH_LLAMA4) { - auto inp = std::make_unique(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); - inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token()); - ggml_set_input(inp_attn_scale); - inp->attn_scale = inp_attn_scale; - res->add_input(std::move(inp)); + inp_attn_scale = build_inp_attn_scale(); } auto * inp_attn = build_attn_inp_kv_unified(); From 09eba6a55492ac72a505a3719243d1aef491a1ab Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 21:36:47 +0200 Subject: [PATCH 18/20] add comment about ggml_repeat --- src/llama-graph.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index be6b1d82889ed..06c6d3279dfff 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -900,6 +900,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { + // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] cur = ggml_mul(ctx0, repeated, weights); @@ -935,6 +936,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); + cb(cur, "ffn_moe_weighted", il); } // aggregate experts From b28cd9ca84f6ce88f3f20339a3e99e1e6aa64765 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 22:10:18 +0200 Subject: [PATCH 19/20] 
clarify comments --- src/llama-graph.cpp | 2 +- src/llama-model.cpp | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 06c6d3279dfff..0d187aba037a0 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -474,7 +474,7 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { } // may need to cut off old tokens for sliding window - // TODO @ngxson : the check for n_attn_chunk is temporary, need to optimize it + // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask" if (data_swa) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2d02f19f16097..4546e9cf9ba96 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -557,11 +557,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - // hack: we use SWA to store the chunked attn mask - // luckily, the n_swa_pattern is the same as chunked layer pattern: 3 chunked - 1 full - hparams.n_swa_pattern = 4; + hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick - hparams.n_swa = 1; // unused, added to trigger the SWA + hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; From d3e67f98beee9ca6cf90d0634f8f767d0f0ca753 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 7 Apr 2025 22:14:18 +0200 Subject: [PATCH 20/20] fix build --- src/llama-graph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0d187aba037a0..c3469177e091c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1027,7 +1027,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { } ggml_tensor * llm_graph_context::build_inp_attn_scale() const { - auto inp = std::make_unique(n_pos_per_token()); + auto inp = std::make_unique(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); auto & cur = inp->attn_scale;
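(Aside on the chunked attention mask introduced in patch 16 and re-commented in patch 19: on the chunked layers — 3 out of every 4, per `n_swa_pattern = 4` — a key/value cell is masked out when it falls before the start of the query position's chunk. A tiny predicate stating that rule, with the 8192 chunk size hard-coded in the patch and the separate causal/sequence checks that the real `set_input` performs omitted.)

```python
ATTN_CHUNK = 8192   # hparams.n_attn_chunk

def visible(kv_pos: int, q_pos: int) -> bool:
    # kv cells before the start of the query's 8192-token chunk get -INFINITY in the mask
    chunk_start = (q_pos // ATTN_CHUNK) * ATTN_CHUNK
    return kv_pos >= chunk_start

print(visible(100, 5000))    # True : same chunk
print(visible(8000, 9000))   # False: kv is in an earlier chunk
```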