@@ -680,14 +680,14 @@ struct llm_graph_context {
     //

     ggml_tensor * build_attn_mha(
-            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * sinks,
-            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale) const;
+            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale) const;

     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

@@ -699,6 +699,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -713,6 +714,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -728,21 +730,8 @@ struct llm_graph_context {
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;

@@ -756,6 +745,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
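Side note, not part of the diff: with `sinks` now threaded through the unified `build_attn` overloads (between `kq_b` and `v_mla`), the separate `build_attn_with_sinks` helper becomes redundant and is removed; callers pass the per-head sink tensor directly, or a null tensor for models without attention sinks. The snippet below is a minimal illustrative sketch of such a call site only; the `inp_attn`, `Qcur`/`Kcur`/`Vcur`, and `model.layers[il].attn_sinks` names are assumptions, not taken from this commit.

    // hypothetical call site inside a model's layer-build loop (sketch, not from this commit)
    ggml_tensor * cur = build_attn(inp_attn,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur,
            /*kq_b  =*/ nullptr,
            /*sinks =*/ model.layers[il].attn_sinks,   // [n_head_q]; nullptr if the model has no sinks
            /*v_mla =*/ nullptr,
            kq_scale, il);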