Commit 54854e1

add tree attention support for blackwell
Signed-off-by: qgai <[email protected]>
1 parent d246f62 commit 54854e1

File tree

15 files changed: 931 additions, 60 deletions

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 7 additions & 1 deletion
@@ -201,6 +201,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     // Medusa mode will have multiple query tokens.
     xqaParams.multi_query_tokens = mIsSpecDecodingEnabled && mUseSpecDecoding;
     xqaParams.is_spec_dec_tree = mIsSpecDecTree;
+    xqaParams.layer_idx = generationsParams.layer_idx;

     if (mKVCacheQuantMode.hasInt8KvCache())
     {
@@ -278,6 +279,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.spec_decoding_is_generation_length_variable
         = generationsParams.spec_decoding_is_generation_length_variable;
     xqaParams.spec_decoding_max_generation_length = generationsParams.spec_decoding_max_generation_length;
+    xqaParams.spec_decoding_bl_tree_mask_offset = generationsParams.spec_decoding_bl_tree_mask_offset;
+    xqaParams.spec_decoding_bl_tree_mask = generationsParams.spec_decoding_bl_tree_mask;
+    xqaParams.spec_bl_tree_first_sparse_mask_offset_kv = generationsParams.spec_bl_tree_first_sparse_mask_offset_kv;
     xqaParams.mrope_position_deltas = generationsParams.mrope_position_deltas;

     xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
@@ -2284,6 +2288,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     // self attn
     XQAParams xqaParams{};
     this->template convertMMHAParamsToXQAParams<T, KVCacheBuffer>(xqaParams, params, /*forConfigurePlugin=*/false);
+
     if (mEnableXQA && mXqaDispatcher->shouldUse(xqaParams))
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
@@ -2908,11 +2913,12 @@ int AttentionOp::initialize() noexcept
     {
         fixedParams.outputDataType = DATA_TYPE_E4M3;
     }
-    if (mIsSpecDecodingEnabled)
+    if (mIsSpecDecodingEnabled && !mUseTllmGen)
     {
         fixedParams.outputDataType = DATA_TYPE_E4M3;
         TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
     }
+
     fixedParams.numQHeads = mNumAttnHeads;
     fixedParams.numKvHeads = mNumAttnKVHeads;
     fixedParams.numTokensPerBlock = mTokensPerBlock;
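The new spec_decoding_bl_tree_* fields plumb a packed tree-attention mask and per-sequence offsets from the runtime into XQAParams for the Blackwell (TRTLLM-Gen) path. Below is a rough host-side sketch of the "packed bitmask plus per-sequence word offset" idea behind spec_decoding_bl_tree_mask and spec_decoding_bl_tree_mask_offset; the actual layout is produced by the prepareCustomMask kernels added in this commit, so the helper and its names are illustrative assumptions, not the real format.

// Illustrative only: packs per-sequence boolean tree masks into 32-bit words
// and records each sequence's starting word offset, mirroring the general
// shape of spec_decoding_bl_tree_mask / spec_decoding_bl_tree_mask_offset.
#include <cstdint>
#include <vector>

struct PackedTreeMask
{
    std::vector<uint32_t> maskWords; // bit b of a word is 1 if attention is allowed
    std::vector<int64_t> seqOffsets; // starting word offset of each sequence's mask
};

PackedTreeMask packTreeMasks(std::vector<std::vector<bool>> const& boolMasks)
{
    PackedTreeMask out;
    for (auto const& mask : boolMasks)
    {
        out.seqOffsets.push_back(static_cast<int64_t>(out.maskWords.size()));
        size_t const base = out.maskWords.size();
        out.maskWords.resize(base + (mask.size() + 31) / 32, 0u);
        for (size_t bit = 0; bit < mask.size(); ++bit)
        {
            if (mask[bit])
            {
                out.maskWords[base + bit / 32] |= (1u << (bit % 32));
            }
        }
    }
    return out;
}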

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 4 additions & 0 deletions
@@ -224,8 +224,12 @@ class AttentionOp
         int32_t const* spec_decoding_generation_lengths = nullptr;
         bool spec_decoding_is_generation_length_variable = false;
         int32_t spec_decoding_max_generation_length = 1;
+        int64_t* spec_decoding_bl_tree_mask_offset = nullptr;
+        uint32_t* spec_decoding_bl_tree_mask = nullptr;
+        int32_t* spec_bl_tree_first_sparse_mask_offset_kv = nullptr;
         // optional when fuse_fp4_quant is enabled
         int32_t start_token_idx_sf = 0;
+        int32_t layer_idx = 0;
     };

     template <typename T, typename KVCacheBuffer>

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 16 additions & 8 deletions
@@ -50,15 +50,18 @@ struct XQAParams
     int32_t sink_token_length = 0;
     int max_past_kv_length = 0;
     void const* qkv_bias;
-    int32_t const* sequence_lengths; //
-    int32_t const* context_lengths;  // maybe not used now
-    void const* alibi_slopes;        // maybe not used now
-    float const* rotary_embedding_inv_freq_cache; // precomputed rotary inv freq
+    int32_t const* sequence_lengths;                   //
+    int32_t const* context_lengths;                    // maybe not used now
+    void const* alibi_slopes;                          // maybe not used now
+    float const* rotary_embedding_inv_freq_cache;      // precomputed rotary inv freq
     int32_t const* spec_decoding_packed_mask;
-    int const* spec_decoding_position_offsets;   // for position embedding.
-    int const* spec_decoding_generation_lengths; // variable input lengths.
-    bool spec_decoding_is_generation_length_variable; // whether the generation lengths actually vary
-    int32_t spec_decoding_max_generation_length; // max possible input length
+    int const* spec_decoding_position_offsets;         // for position embedding.
+    int const* spec_decoding_generation_lengths;       // variable input lengths.
+    bool spec_decoding_is_generation_length_variable;  // whether the generation lengths actually vary
+    int32_t spec_decoding_max_generation_length;       // max possible input length
+    int64_t* spec_decoding_bl_tree_mask_offset;        // for blackwell spec-dec tree mask offset
+    uint32_t* spec_decoding_bl_tree_mask;              // for blackwell spec-dec tree mask
+    int32_t* spec_bl_tree_first_sparse_mask_offset_kv; // for blackwell spec-dec tree first sparse mask offset kv
     int32_t const* mrope_position_deltas = nullptr;

     // almost copy from GPTAttentionPluginCommon.
@@ -115,6 +118,8 @@ struct XQAParams
     bool use_sparse_attention = false;

     cudaStream_t stream = 0;
+    // layer index
+    int32_t layer_idx = 0;

     std::string toString() const
     {
@@ -149,6 +154,9 @@ struct XQAParams
            << "spec_decoding_is_generation_length_variable: "
            << (spec_decoding_is_generation_length_variable ? "true" : "false") << std::endl
            << "spec_decoding_max_generation_length: " << spec_decoding_max_generation_length << std::endl
+           << "spec_decoding_bl_tree_mask_offset: " << spec_decoding_bl_tree_mask_offset << std::endl
+           << "spec_decoding_bl_tree_mask: " << spec_decoding_bl_tree_mask << std::endl
+           << "spec_bl_tree_first_sparse_mask_offset_kv: " << spec_bl_tree_first_sparse_mask_offset_kv << std::endl
            << "mrope_position_deltas: " << mrope_position_deltas << std::endl
            << "generation_input_length: " << generation_input_length << std::endl
            << "num_q_heads: " << num_q_heads << std::endl

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 30 additions & 3 deletions
@@ -29,7 +29,10 @@
 #include "fmhaReduction.h"
 #include "fmhaRunnerParams.h"
 #include "kernelParams.h"
+#include "prepareCustomMask.h"
+#include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
+#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"

 namespace tc = tensorrt_llm::common;

@@ -204,6 +207,11 @@ class TllmGenFmhaKernel
                 selectKernelIter++;
                 continue;
             }
+            // Prepare custom mask for spec-decoding generation kernels.
+            if (params.layer_idx == 0 && params.is_spec_dec_tree)
+            {
+                runPrepareCustomMask(kernelMeta, params, params.stream);
+            }

             // Prepare the kernel parameters.
             auto kernelParams = KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv);
@@ -518,9 +526,24 @@
         }
         else if (isGenerationKernel(params.mKernelType))
         {
-            kernelType = (params.mNumHeadsQPerKv <= 16 && params.mHeadDimQk != 32)
-                ? FmhaKernelType::SwapsMmaAbForGeneration
-                : FmhaKernelType::KeepsMmaAbForGeneration;
+            if (params.is_spec_dec_tree)
+            {
+
+                if (params.mNumHeadsQPerKv <= 16 && (params.mHeadDimQk == 64 || params.mHeadDimQk == 128))
+                {
+                    kernelType = FmhaKernelType::KeepsMmaAbForGeneration;
+                }
+                else
+                {
+                    kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
+                }
+            }
+            else
+            {
+                kernelType = (params.mNumHeadsQPerKv <= 16 && params.mHeadDimQk != 32)
+                    ? FmhaKernelType::SwapsMmaAbForGeneration
+                    : FmhaKernelType::KeepsMmaAbForGeneration;
+            }
         }

         // The maximum number of headsQPerKv that the kernel can support in one Cta.
@@ -538,6 +561,10 @@
         {
             // Use the maxNumHeadsQPerKvInCta (tileSizeQ) = 64 for MLA high-throughput generation kernels.
             maxNumHeadsQPerKvInCta = isMlaGenKernel(params) ? 64 : 32;
+            if (params.is_spec_dec_tree)
+            {
+                maxNumHeadsQPerKvInCta = 128;
+            }
             TLLM_CHECK_WITH_INFO((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta
                                      || params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0),
                 "Not supported");

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h

Lines changed: 11 additions & 3 deletions
@@ -207,12 +207,14 @@ struct TllmGenFmhaRunnerParams
     void const* qkvPtr;
     // The attention sinks pointer (additional value per head in the denominator of the softmax).
     float const* attentionSinksPtr;
+    // The general packed custom mask ptr which does not meet specific format for trtllm gen kernels.
+    int32_t const* generalPackedCustoMaskPtr;
     // The custom mask ptr.
-    uint32_t const* customMaskPtr;
+    uint32_t* customMaskPtr;
     // The packed custom mask's offsets of each sequence.
-    int64_t const* customMaskOffsetsPtr;
+    int64_t* customMaskOffsetsPtr;
     // The first sparseMask offsets in the Kv sequence dimension.
-    int32_t const* firstSparseMaskOffsetsKvPtr;
+    int32_t* firstSparseMaskOffsetsKvPtr;
     // The counter for the multiCtasKv mode.
     int32_t* multiCtasKvCounterPtr;
     // The sequence length buffer for K/V.
@@ -240,6 +242,8 @@ struct TllmGenFmhaRunnerParams
     void* oPtr;
     // The output scaling factor buffer.
     void* oSfPtr;
+    // The spec-decoding generation lengths.
+    int const* spec_decoding_generation_lengths;

     // Head dimension for Q and K.
     int mHeadDimQk;
@@ -284,6 +288,10 @@ struct TllmGenFmhaRunnerParams
     int mSparseMlaTopK;
     // The cuda stream.
     cudaStream_t stream;
+    // The layer index.
+    int32_t layer_idx = 0;
+    // Whether the spec-dec tree is used.
+    bool is_spec_dec_tree = false;

     // set the attention mask type
     TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)
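Dropping const from customMaskPtr, customMaskOffsetsPtr, and firstSparseMaskOffsetsKvPtr, together with the new layer_idx and is_spec_dec_tree fields, matches the layer_idx == 0 gate in fmhaKernels.h: the packed custom mask is prepared once per step and then reused by every layer. A hedged sketch of that flow follows, using a simplified struct and hypothetical helper names (nothing below is the real runner API).

#include <cstdint>

// Simplified stand-in for TllmGenFmhaRunnerParams; only the fields relevant
// to the once-per-step custom-mask preparation are shown.
struct RunnerParamsLite
{
    int32_t const* generalPackedCustoMaskPtr = nullptr; // input mask, general packed format
    uint32_t* customMaskPtr = nullptr;                  // written by the layer-0 prepare step
    int64_t* customMaskOffsetsPtr = nullptr;            // written by the layer-0 prepare step
    int32_t* firstSparseMaskOffsetsKvPtr = nullptr;     // written by the layer-0 prepare step
    int32_t layer_idx = 0;
    bool is_spec_dec_tree = false;
};

void runAttentionLayer(RunnerParamsLite& params)
{
    if (params.is_spec_dec_tree && params.layer_idx == 0)
    {
        // prepareCustomMaskOnce(params); // hypothetical: converts the general packed mask
        //                                // into the kernel-specific packed layout.
    }
    // launchFmhaGeneration(params);      // hypothetical launch; all layers read the same
    //                                    // prepared custom mask buffers.
}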
