Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Granite Four #13550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 132 commits into
base: master
Choose a base branch
from
Draft

Granite Four #13550

wants to merge 132 commits into from

Conversation

gabe-l-hart
Copy link
Contributor

@gabe-l-hart gabe-l-hart commented May 14, 2025

Description

This PR is the end-point for architecture support for Granite 4.0 (#13269 . It incorporates a number of changes from other in-flight branches that will need to be merged first:

Additionally, this PR replaces some work done on other PRs / branches:

Outstanding Questions

Besides the upstream PRs, there are a few questions to answer before this PR is merge ready:

  • This PR contains several changes to llama-kv-cache beyond those in feat: Hybrid unified/recurrent cache #13276, but they depend on the addition of hparams.recurrent_layer_arr which is only populated correctly if there is a valid model architecture to check against. Should I move all of these changes to the hybrid cache PR or keep them here where the model architectures become real?
  • Is there a more efficient way to implement hparams.recurrent_layer_arr? Using a max-layer-size std::array doesn't feel quite right.
  • There are still some numerical differences between the attention outputs when running Bamba and granite-4.0-tiny-shared-preview on this branch vs the respective draft branches, so I need to determine if this is due to changes in the attention implementation (ie "working as expected") or a bug somewhere.
  • The use of dymamic_cast to get the right cache type could be expensive (though it's likely negligible relative to the tensor math). Should we do something more clever to handle different cache types in llama-graph?
  • The switch statement for determining the type of KV cache to allocate in llama-model.cpp seems redundant with llama_model_is_recurrent and llama_model_is_hybrid. Should we use those functions instead and eliminate the duplicate logic and additional place to tweak for new recurrent / hybrid models?

Testing

To test out this branch, I've been using the following models:

Details

This PR has a lot of changes in it, some of which are isolated in the prereq-PRs above. In addition to the general mamba2 and llama_kv_cache_hybrid changes, this PR does the following:

python side

  • Add conversion support for BambaForCausalLM and GraniteMoeHybridForCausalLM
    • This includes one small tweak to gguf_writer.py that allows duplicate key/value pairs through add_key_value if (and only if) they match both value and type with the existing key. This is a convenience for hybrid models so that the converter doesn't need to rewrite the hparam conversion from multiple parents.
    • This also adds the new HybridAttention section under Keys in constants.py to hold attention.layer_indices. OPEN QUESTION: Should this just go under Attention?

c++ side

  • Add a new public API function llama_model_is_hybrid akin to llama_model_is_recurrent
    • I also split up both this function and llama_model_is_recurrent into llm_arch_is_* implemented in llama-arch.* and llama_model_is_* implemented in llama-model.*. This was done so that they could be used during model initialization before the model itself can be passed as the argument, specifically to determine how to populate hparams.recurrent_layer_arr (see below).
  • Add hparams.recurrent_layer_arr and support parsing it
    • The current implementation pre-allocates it as a fixed-length array which doesn't feel quite right.
  • Add an optional layer id to hparams.n_embd_k_s / hparams.n_embd_v_s
    • This is done because for hybrid models, the values may be different by layer.
    • I plumbed through as many usages of these methods as I could find to properly pass the layer index, but there are some places where it's not available which default to layer 0. This should be fine since none of those places interact with the hybrid caching.
  • Add hparams.recurrent_layer(uint32_t) to check whether a given layer is recurrent
  • Model name/param/arch plumbing for bamba and granitemoeshared in llama-arch.* (the boring part!)
  • (possibly breaking) Add hparams as an additional argument to the llama_model.create_memory method
    • This is done so the hparams can be given to the cache construction and used to determine which layers are recurrent for hybrid cache creation
  • In llama-graph, anywhere that a specific cache type needs to be fetched, it is grabbed using new methods get_recurrent_cache / get_unified_cache. These methods use dynamic_cast to handle both non-hybrid caches and hybrid caches.
  • Add support for instantiating the hybrid cache in llama-model.cpp
  • Add model support for bamba and granitemoehybrid in llama-model
    • Most of this is "business as usual," but that breaks down when trying to avoid code duplication for the hybrid architecture
    • To avoid code duplication, I hoisted build_mamba_layer / build_mamba2_layer from llm_build_mamba and build_attention_layer / build_layer_ffn from llm_build_granite into static methods on their respective classes. This makes for some gross function signatures where member data needs to be explicitly passed, but it allows the hybrid model architecture(s) to use these methods without complex inheritance.
    • I tried an alternative route using diamond inheritance, but this would have required some kind of "don't actually initialize the graph" switch in the parent model builders' constructors to avoid trying to build the parent model graphs during initialization of the hybrid class.

compilade added 30 commits April 3, 2024 20:47
This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot
* ggml : simplify SSM-related operators

* llama : make recurrent state slot allocation contiguous

* llama : adapt internal uses of batches to llama_ubatch
This reduces overhead when running hellaswag
on thousands of sequences with very small 100k params Mamba models.
This otherwise was a problem when running the HellaSwag benchmark
with small batch sizes, making it crash.
This removes the need for ggml_ssm_conv!!!
But performance seems slighly worse on my system,
especially for prompt processing.
Maybe ggml_mul_mat isn't optimized for small row sizes?
More performance testing is necessary until GGML_OP_SSM_CONV is removed.

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.
* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors

The implementation already supported it,
and this makes Mamba's conv step slightly faster.
This can be changed back later if the name change is wrong.
I was renaming the functions anyway to generalize kv-cache-related
functions to hybrid and recurrent model architectures.
I think llama_past is a better name than llama_cache for a combined
kv cache and recurrent state cache, because the states it contains
pretty much always come before the newly-added ones for any particular
sequence. Also 'llama_past_clear' sounds more obvious in what it does
than 'llama_kv_cache_clear'. The future is what the models generate.
(For embeddings, the kv cache isn't really used anyway)

Still, I'm open to better suggestions.
* origin/master:
memory : correctly handle failure in apply() (ggml-org#14438)
* origin/master:
Add Vulkan images to docker.md (ggml-org#14472)
CANN: update aclnnGroupedMatmulV2 to aclnnGroupedMatmulV3 (ggml-org#14411)
vulkan: Split large mul_mat_id to fit in shared memory (ggml-org#14451)
add GELU_ERF (ggml-org#14455)
ggml : remove trailing whitespace (#0)
sync : ggml
ggml-cpu : "align corners" for bilinear upscale/downscale (ggml/1285)
ggml-quants : rename best_mad to best_error (ggml/1283)
opencl : add GEGLU, REGLU, SWIGLU (ggml-org#14456)
Add Conv2d for CPU (ggml-org#14388)
* origin/master:
llama : initial Mamba-2 support (ggml-org#9126)
sync : ggml
ggml : add version function to get lib version (ggml/1286)
Set RPATH to "@loader_path" / "$ORIGIN" to ensure executables and dynamic libraries search for dependencies in their origin directory. (ggml-org#14309)
CUDA: add softmax broadcast (ggml-org#14475)
CUDA: broadcasting for FlashAttention mask (ggml-org#14500)
vulkan: support softmax/FA batch and broadcast (ggml-org#14449)
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (ggml-org#14435)
opencl : fix possible buffer overflow in dump_tensor (ggml-org#14490)
simple-chat : fix context-exceeded condition (ggml-org#14494)
opencl : skip empty nodes on cgraph compute (ggml-org#14491)
opencl : update upscale to support align corners (ggml-org#14488)
ci : add OpenCL to labeler workflow (ggml-org#14496)
github : add OpenCL backend to issue templates (ggml-org#14492)
ggml : Callback before abort (ggml-org#14481)
ci : disable fast-math for Metal GHA CI (ggml-org#14478)
@gabe-l-hart gabe-l-hart marked this pull request as ready for review July 2, 2025 17:46
@gabe-l-hart
Copy link
Contributor Author

@compilade, huge thanks for pushing #9126 over the line! With that merged, this is ready for full review (cc @ggerganov). This will be the first real instance of the hybrid recurrent cache implementation.

@@ -4875,6 +4875,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
Copy link
Contributor Author

@gabe-l-hart gabe-l-hart Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pulled these into the class so that they can be set differently by derived conversion classes and then used in the common methods below

EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_MOE_HYBRID = auto()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been one of the most annoying changes keeping this branch up to date: The GRANITE_MOE_HYBRID name is two characters longer than the previous longest name, so to keep vertical alignment, it changes the indentation of all values (here and in llama-arch.cpp).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
cur = build_granite_attention_layer(
Copy link
Contributor Author

@gabe-l-hart gabe-l-hart Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had originally extracted these as standalone methods so that I could reuse them in the hybrid implementation. Ultimately, any inheritance / static method / mixin approach I tried felt too tangled, so tangled, so I went back to simply duplicating these methods in the hybrid model. I left these separated out here for the symmetry and encapsulation, but I could also revert this set of changes to llm_build_granite to keep the changeset smaller.

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}

ggml_tensor * build_mamba2_layer(
Copy link
Contributor Author

@gabe-l-hart gabe-l-hart Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy-paste from llm_build_mamba. Per the other comment, it got too tangled to try to reliably reuse these across model builders. That said, I would still love to find a way to avoid this kind of duplication if there's appetite.

Copy link
Collaborator

@compilade compilade Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gabe-l-hart I might have found a way to avoid this kind of duplication, see src/llama-model.cpp and src/llama-graph.cpp in #7531

llm_graph_context_mamba is a child class of llm_graph_context with Mamba-specific layer builders

llama.cpp/src/llama-model.cpp

Lines 9883 to 9894 in 908e655

struct llm_graph_context_mamba : public llm_graph_context {
llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
ggml_tensor * build_mamba_layer(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_model & model,
const llama_ubatch & ubatch,
int il) {
const auto * mctx_cur = inp->mctx;

llm_graph_context_mamba is the parent class of llm_build_mamba and llm_build_jamba. Not sure if that would still be appropriate with multiple-inheritance, though that wasn't necessary for Jamba.

The methods could potentially be moved to llm_graph_context (in src/llama-graph.cpp), but I preferred to keep model-specific graph building methods in src/llama-model.cpp, at least for now.

llama.cpp/src/llama-model.cpp

Lines 10156 to 10157 in 908e655

struct llm_build_mamba : public llm_graph_context_mamba {
llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {

Note that I've also removed build_inp_mem_hybrid and llm_graph_input_hybrid in favor of directly using the recurrent and self-attention input builders separately. This is relatively clean, I think.

build_inp_rs and build_inp_attn_kv_unified accept an optional mctx override argument.

llama.cpp/src/llama-model.cpp

Lines 10213 to 10259 in 908e655

struct llm_build_jamba : public llm_graph_context_mamba {
llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
ggml_tensor * cur;
ggml_tensor * inpL;
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);
const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
const int64_t n_head_kv = hparams.n_head_kv(il);
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
if (n_head_kv == 0) {
cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
} else {
// Attention
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// No RoPE :)
cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);

This makes use of the fact that mctx is stored in inp_rs and inp_attn already, and so build_rs and build_attn were changed to use that instead of trying to cast llm_graph_context::mctx again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llm_graph_context_mamba is a child class of llm_graph_context with Mamba-specific layer builders

Ah, very nice. This is quite close to something else I tried earlier using the mixin approach. You can see it here:

(sidebar, what is the magical GH markdown syntax to get those nice multi-line reference blocks to render inline??)

I think the only real difference here is that the mixins couldn't inherit from llm_graph_context in order to support multiple inheritance and avoid double-initialization. Unfortunately, this meant that mixin classes had to hold a pointer to an llm_graph_context (the eventual child) and the implementations of the shared methods needed to to self->... everywhere instead of just using shared member methods from llm_graph_context.

The other approach I tried was to simply make the layer-builder methods static so that they can simply be called by other classes. I don't have a good point in history to show this because I stopped being super careful about maintaining a working version of this during rebasing, but you can see the gist of it here for the static builder method and here for calling it. On that hash, I think I accidentally dropped the static on the mamba layer builders (it was there at some point), but the idea was the same. Just like in the mixin approach, in order to avoid multiple-inheritance with multi-initialization, the "fix" was to pass this as an argument to the static methods and have them do self->... everywhere in their body.

In both cases, the tradeoff I struggled to make was forcing the implementation of the layer builder to know that it would be used by hybrid models. I'm certainly not against this, but it felt like a bit of an inversion of responsibility (the "child" in the hybrid relationship becoming the "parent" in the implementation hierarchy). I'm very open to revisiting this, especially the mixin approach since I think that is probably the cleanest way to handle this from a design pattern perspective. The only other thought I had was to use a templated generic hybrid builder, but I know templates are generally avoided in this project.


Note that I've also removed build_inp_mem_hybrid and llm_graph_input_hybrid in favor of directly using the recurrent and self-attention input builders separately. This is relatively clean, I think.

build_inp_rs and build_inp_attn_kv_unified accept an optional mctx override argument.

Nice, this feels much better. I had also tried something like this, but @ggerganov had suggested it was best to start by following the build_attn pattern strictly. I like the idea of revisiting this now that we've sorted out some of the other issues with the hybrid cache.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The methods could potentially be moved to llm_graph_context (in src/llama-graph.cpp), but I preferred to keep model-specific graph building methods in src/llama-model.cpp, at least for now.

I thought about this one too, and I agree that putting mamba or granite methods in the upstream base class is not correct. I also tried using virtual inheritance to enable diamond-style multiple inheritance without multiple-initialization of the common base class, but wasn't able to get it working (I can't quite remember why now).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(sidebar, what is the magical GH markdown syntax to get those nice multi-line reference blocks to render inline??)

It's a permalink to ranges of lines on its own Markdown line. I think it has to be from the same repository, though (so only code from https://github.com/ggml-org/llama.cpp (including branches) will have pretty previews here).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I got the mixin pattern while also using diamond inheritance so the mixins can inherit from llm_graph_context. The only change on top of the changes in #7531 is to make llm_graph_context_mamba use virtual inheritance from llm_graph_context (struct llm_graph_context_mamba : public virtual llm_graph_context {). This then means that all classes that inherit from one or more of these mixins need to explicitly initialize llm_graph_context as well as the mixin(s). This does result in some duplicate initialization since the initializer list and constructor of llm_graph_context will still be called once for each parent in the diamond (so twice if single inheritance, three times if diamond inheritance). This should be a negligible runtime hit (especially if we can avoid re-creating res since that's the only heap allocation), but it will still end up setting all of the members multiple times. If we can make this all compatible with #14482, then the hit should be trivial since it will not happen for every token.

Copy link
Contributor Author

@gabe-l-hart gabe-l-hart Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a full implementation of this with #7531 merged in https://github.com/gabe-l-hart/llama.cpp/blob/GraniteFourWithJamba. Depending on how we want to order the merges, I'm happy to update this branch to include that refactor and put it behind Jamba for merge.

Copy link
Contributor Author

@gabe-l-hart gabe-l-hart Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dug a bit more and I actually think virtual inheritance is exactly what we want and there will be zero duplicate initialization (though there would be duplicate construction if there were a body to the constructor for llm_graph_context). Here's a dummy sample:

virtual-inheritance.cpp
// Example program
#include <iostream>
#include <string>

struct Member {
    Member(int i_) : i(i_) {
        std::cout << "Constructing Member with " << i << std::endl;
    }
    
    const int i;
};

struct Base {
    Base(int i) : m(i) {}

    Member m;
};

struct MixinA : public virtual Base {
    MixinA() : Base(1) {}
    
    void do_a() const {
        std::cout << "Doing A: " << m.i << std::endl;
    }
};

struct MixinB : public virtual Base {
    MixinB() : Base(2) {}
    
    void do_b() const {
        std::cout << "Doing B: " << m.i << std::endl;
    }
};

struct Kid : public MixinA, public MixinB {
    Kid() : Base(3), MixinA(), MixinB() {}
    
    void doit() const {
        do_a();
        do_b();
    }
};

int main()
{
    Kid k;
    k.doit();
}

When run, this outputs:

Constructing Member with 3
Doing A: 3
Doing B: 3

This indicates that the explicit initialization of Base in Kid is used, but that the initialization of Base in MixinA and MixinB is ignored which is exactly what we want here.

…d_mamba

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart and others added 9 commits July 2, 2025 12:49
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
* origin/master:
gguf-py : add support for chat template jinja files (ggml-org#14508)
* origin/master:
Fix conditional enabling following arch checks for ggml-sycl (ggml-org#14504)
convert : correct gemma 3n conversion (ggml-org#14450)
kv-cache : use ggml_set_rows (ggml-org#14285)
ggml : fix FA mask dim 2 and 3 (ggml-org#14505)
ggml : remove kompute backend (ggml-org#14501)
CUDA: add dynamic shared mem to softmax, refactor general usage (ggml-org#14497)
…o GraniteFourWithJamba

* origin/compilade/refactor-kv-cache: (32 commits)
convert : fix jamba conv1d shape squeezing
llama : partially apply clang-format style
llama : remove implicit recurrent state rollbacks
llama : begin renaming llama_past back to llama_kv_cache
llama : use unused n_embd_k_gqa in k_shift
llama : fix mixed signedness comparison
convert_hf : fix Jamba conversion
llama : session saving and reloading for hybrid models
mamba : fix non-contiguous usage of ggml_silu
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
llama : rename llama_cache to llama_past
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
llama : fix .base() compilation error on Windows
llama : use im2col and mul_mat to perform convolution for Mamba
llama : avoid copies for simple batch splits
llama : fix edge case finding batch seq_id of split recurrent cell
llama : minimize swaps when reordering logits
llama : fix batch split output count for embeddings
llama : use equal-sequence-length sub-batches for recurrent models
llama : sequence-length-aware batch splitting
...
…id inputs

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <[email protected]>
…as mixins

The key is for the mixin classes (llm_graph_context_mamba,
llm_graph_context_granite) to use virtual inheritance from
llm_graph_context. This allows the common members to exist only once in the
class hierarchy. The downside is that llm_graph_context will be
re-initialized once for each parent (ie 2x for single mixin, 3x for two
mixins, etc...).

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart
Copy link
Contributor Author

Given the changes in #7531 that relate to this same architecture and the draft I have with the mixin pattern and virtual inheritance, I'm bumping this one back behind Jamba in merge order and will update this PR to include the changes in #7531.

@gabe-l-hart gabe-l-hart marked this pull request as draft July 3, 2025 19:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) examples ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs python python script changes server testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants