53 changes: 31 additions & 22 deletions samples/AttentionFMHA.py
@@ -31,7 +31,8 @@
def fmha_kernel(Q, K, V, Out,
qk_scale: float,
input_pos: int,
TILE_D: ConstInt, # TILE_D = hidden_size
Dqk: ConstInt, # Head dimension of Q and K
Dv: ConstInt, # Head dimension of V
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
@@ -64,12 +65,12 @@ def fmha_kernel(Q, K, V, Out,
# Initialize online softmax accumulators in float32 for stability
m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
acc = ct.full((TILE_M, Dv), 0.0, dtype=np.float32)

# Load query tile for this batch, head, and M-chunk
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, Dqk)
).reshape((TILE_M, Dqk)) # [TILE_M, Dqk]

# loop over k, v and update accumulator
m_end = input_pos + (bid_x + 1) * TILE_M
@@ -88,11 +89,11 @@
for j in range(0, Tc):
# --- Compute QK product ---
k = ct.load(
K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, Dqk, TILE_N),
order=(0, 1, 3, 2),
latency=2,
)
k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N]
k = k.reshape((Dqk, TILE_N)) # [Dqk, TILE_N]
qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
qk = ct.mma(q, k, qk) # [TILE_M, TILE_N]

@@ -125,16 +126,16 @@ def fmha_kernel(Q, K, V, Out,

# --- Compute PV product ---
v = ct.load(
V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, Dv),
latency=4,
).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D]
).reshape((TILE_N, Dv)) # [TILE_N, Dv]
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc) # [TILE_M, Dv]
m_i = m_ij # [TILE_M, 1]

# --- Final Normalization and Store ---
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
acc = acc.reshape((1, 1, TILE_M, Dv)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
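
Note for readers following the kernel: the softmax-update block elided between the hunks above is the standard online-softmax (FlashAttention-style) recurrence, and the only effect of this change on it is that `acc` and the stored output now carry `Dv` columns while the scores are computed over `Dqk`. The sketch below is plain NumPy, not the cuTile API; the function name, `tile_n`, and the self-check are illustrative.

```python
import numpy as np

def online_softmax_attention(q, k, v, tile_n=16):
    """Streaming (online-softmax) attention over K/V tiles.
    q: [M, Dqk], k: [N, Dqk], v: [N, Dv]  -- a sketch, not the kernel itself."""
    M, Dqk = q.shape
    N, Dv = k.shape[0], v.shape[1]
    scale = 1.0 / np.sqrt(Dqk)
    m_i = np.full((M, 1), -np.inf, dtype=np.float32)   # running row max
    l_i = np.zeros((M, 1), dtype=np.float32)           # running row sum of exp
    acc = np.zeros((M, Dv), dtype=np.float32)          # running weighted sum of V
    for j in range(0, N, tile_n):
        s = (q @ k[j:j + tile_n].T) * scale            # [M, tile] scores
        m_new = np.maximum(m_i, s.max(axis=1, keepdims=True))
        p = np.exp(s - m_new)                          # probabilities vs. new max
        alpha = np.exp(m_i - m_new)                    # rescale factor for old state
        l_i = alpha * l_i + p.sum(axis=1, keepdims=True)
        acc = alpha * acc + p @ v[j:j + tile_n]        # [M, Dv]
        m_i = m_new
    return acc / l_i                                   # final normalization

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q = rng.standard_normal((8, 16)).astype(np.float32)   # [M, Dqk]
    k = rng.standard_normal((32, 16)).astype(np.float32)  # [N, Dqk]
    v = rng.standard_normal((32, 4)).astype(np.float32)   # [N, Dv], Dv != Dqk
    s = (q @ k.T) / np.sqrt(16)
    ref = np.exp(s - s.max(-1, keepdims=True))
    ref = (ref / ref.sum(-1, keepdims=True)) @ v
    assert np.allclose(online_softmax_attention(q, k, v), ref, atol=1e-5)
```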


@@ -202,6 +203,7 @@ def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
qk_scale,
input_pos,
D_k,
D_v,
Heads,
tile_m,
tile_n,
@@ -273,12 +275,18 @@ def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,


def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
is_causal: bool, enable_gqa: bool) -> torch.Tensor:
backend = SDPBackend.CUDNN_ATTENTION \
if (Q.shape[2] == K.shape[2]) \
else SDPBackend.FLASH_ATTENTION
with sdpa_kernel(backend):
ret = scaled_dot_product_attention(Q, K, V,
is_causal: bool, enable_gqa: bool,
use_backend_selection_rule: bool = False) -> torch.Tensor:
if use_backend_selection_rule:
backend = SDPBackend.CUDNN_ATTENTION \
if (Q.shape[2] == K.shape[2]) \
else SDPBackend.FLASH_ATTENTION
with sdpa_kernel(backend):
ret = scaled_dot_product_attention(Q, K, V,
is_causal=is_causal,
enable_gqa=enable_gqa)
else:
ret = scaled_dot_product_attention(Q, K, V,
is_causal=is_causal,
enable_gqa=enable_gqa)
return ret
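
Usage sketch for the new opt-in flag (assumes the `torch_fmha` defined above is in scope, a CUDA device, and a PyTorch build recent enough to expose `enable_gqa`; tensor shapes are placeholders with equal head dims so either pinned backend can run):

```python
import torch

# Placeholder GQA tensors: 32 query heads sharing 4 K/V heads, head dim 64.
Q = torch.randn(2, 32, 128, 64, dtype=torch.float16, device="cuda")
K = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
V = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")

# Default path: PyTorch picks the SDPA backend on its own.
ref = torch_fmha(Q, K, V, is_causal=True, enable_gqa=True)

# Opt-in path: pin cuDNN attention when Q and K sequence lengths match,
# otherwise FlashAttention (the rule guarded by use_backend_selection_rule).
ref_pinned = torch_fmha(Q, K, V, is_causal=True, enable_gqa=True,
                        use_backend_selection_rule=True)
```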
@@ -296,13 +304,14 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,

# --- User Configuration ---
BATCH_SIZE = 2
NUM_HEADS = 8
NUM_HEADS = 32
SEQ_LEN_Q = 128
SEQ_LEN_KV = 128
D_K = 64
SEQ_LEN_KV = 256
D_K = 128
D_V = 64

QUERY_GROUP_SIZE = 1
QUERY_GROUP_SIZE = 8
enable_gqa = QUERY_GROUP_SIZE > 1

DTYPE = torch.float16

@@ -336,7 +345,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
dtype:{output_fmha_cutile_non_causal.dtype}""")
if args.correctness_check:
ref_fmha = torch_fmha(Q_input, K_input, V_input,
is_causal=False, enable_gqa=False)
is_causal=False, enable_gqa=enable_gqa)
torch.testing.assert_close(output_fmha_cutile_non_causal, ref_fmha, atol=1e-3, rtol=1e-3)
print("Correctness check passed")
else:
@@ -354,7 +363,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
dtype: {output_fmha_cutile_causal.dtype}""")
if args.correctness_check:
ref_fmha = torch_fmha(Q_input, K_input, V_input,
is_causal=True, enable_gqa=False)
is_causal=True, enable_gqa=enable_gqa)
torch.testing.assert_close(output_fmha_cutile_causal, ref_fmha, atol=1e-3, rtol=1e-3)
print("Correctness check passed")
else:
@@ -394,7 +403,7 @@ def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
dtype: {output_fmha_cutile_autotune_causal.dtype}""")
print(f"Tuned config: {tuned_config}")
if args.correctness_check:
ref_fmha = torch_fmha(Q_input, K_input, V_input, is_causal=True, enable_gqa=False)
ref_fmha = torch_fmha(Q_input, K_input, V_input, is_causal=True, enable_gqa=enable_gqa)
torch.testing.assert_close(
output_fmha_cutile_autotune_causal, ref_fmha, atol=1e-2, rtol=5e-2
)
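
For context on the revised defaults (32 query heads, query group size 8, decoupled `D_K = 128` / `D_V = 64`), here is a shape sketch of the tensors the sample feeds the kernel, following the repo's `query_group_size = q_heads // k_heads` convention; allocation details elided from the diff are not reproduced.

```python
import torch

# Shape sketch for the updated defaults: GQA with decoupled Q/K and V head dims.
BATCH_SIZE, NUM_HEADS, QUERY_GROUP_SIZE = 2, 32, 8
SEQ_LEN_Q, SEQ_LEN_KV, D_K, D_V = 128, 256, 128, 64
KV_HEADS = NUM_HEADS // QUERY_GROUP_SIZE              # 4 grouped K/V heads

Q = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_Q, D_K, dtype=torch.float16, device="cuda")
K = torch.randn(BATCH_SIZE, KV_HEADS, SEQ_LEN_KV, D_K, dtype=torch.float16, device="cuda")
V = torch.randn(BATCH_SIZE, KV_HEADS, SEQ_LEN_KV, D_V, dtype=torch.float16, device="cuda")

# The attention output inherits V's head dimension, not Q/K's.
expected_out_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN_Q, D_V)
```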
1 change: 1 addition & 0 deletions samples/templates/AttentionFMHA.py
@@ -83,6 +83,7 @@ def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
qk_scale,
input_pos,
D_k,
D_v,
Heads,
tile_m,
tile_n,
16 changes: 9 additions & 7 deletions test/bench_attention.py
@@ -74,10 +74,11 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark):
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
)

B, H, L, D = q.shape
# first gemm mma(q, k): 2 * B * H * L * L * D
# second gemm mma(p, v): 2 * B * H * L * L * D
flop_count = 4 * B * H * L * L * D
B, H, L, Dqk = q.shape
_, _, _, Dv = v.shape
# first gemm mma(q, k): 2 * B * H * L * L * Dqk
# second gemm mma(p, v): 2 * B * H * L * L * Dv
flop_count = 2 * B * H * L * L * (Dqk + Dv)

if is_causal:
flop_count /= 2
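
A quick worked instance of the revised FLOP estimate (the shape values below are illustrative, not the bench's parametrization):

```python
# Both GEMMs are counted separately now that Dqk and Dv can differ.
B, H, L, Dqk, Dv = 2, 32, 4096, 128, 64
flops = 2 * B * H * L * L * Dqk       # first gemm: mma(q, k)
flops += 2 * B * H * L * L * Dv       # second gemm: mma(p, v)
flops /= 2                            # causal masking skips roughly half the tiles
print(f"{flops / 1e12:.2f} TFLOP per forward pass")
```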
@@ -88,9 +89,10 @@


def cutile_fmha(q, k, v, o, is_causal, enable_gqa):
b, qh, q_len, d = q.shape
b, qh, q_len, dqk = q.shape
_, kh, k_len, _ = k.shape
qk_scale = 1 / sqrt(d)
_, _, _, dv = v.shape
qk_scale = 1 / sqrt(dqk)
TILE_M, TILE_N = (256, 128) if is_causal else (64, 128)
query_group_size = qh // kh
grid = (ceil(q_len / TILE_M), b * qh, 1)
@@ -100,7 +102,7 @@ def cutile_fmha(q, k, v, o, is_causal, enable_gqa):
(q, k, v, o,
qk_scale,
input_pos,
d, qh,
dqk, dv, qh,
TILE_M, TILE_N,
query_group_size, is_causal, EVEN_K))

19 changes: 10 additions & 9 deletions test/kernels/attention.py
@@ -21,7 +21,8 @@
def fmha_kernel(Q, K, V, Out,
qk_scale: float,
input_pos: int,
TILE_D: ConstInt, # TILE_D = hidden_size
Dqk: ConstInt, # Head dimension of Q and K
Dv: ConstInt, # Head dimension of V
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
@@ -54,12 +55,12 @@ def fmha_kernel(Q, K, V, Out,
# Initialize online softmax accumulators in float32 for stability
m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
acc = ct.full((TILE_M, Dv), 0.0, dtype=np.float32)

# Load query tile for this batch, head, and M-chunk
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, Dqk)
).reshape((TILE_M, Dqk)) # [TILE_M, Dqk]

# loop over k, v and update accumulator
m_end = input_pos + (bid_x + 1) * TILE_M
@@ -78,11 +79,11 @@
for j in range(0, Tc):
# --- Compute QK product ---
k = ct.load(
K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, Dqk, TILE_N),
order=(0, 1, 3, 2),
latency=2,
)
k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N]
k = k.reshape((Dqk, TILE_N)) # [Dqk, TILE_N]
qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
qk = ct.mma(q, k, qk) # [TILE_M, TILE_N]

@@ -115,14 +116,14 @@

# --- Compute PV product ---
v = ct.load(
V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, Dv),
latency=4,
).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D]
).reshape((TILE_N, Dv)) # [TILE_N, Dv]
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc) # [TILE_M, Dv]
m_i = m_ij # [TILE_M, 1]

# --- Final Normalization and Store ---
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
acc = acc.reshape((1, 1, TILE_M, Dv)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
14 changes: 7 additions & 7 deletions test/test_attention.py
@@ -15,21 +15,21 @@
@pytest.mark.parametrize("k_heads", [8])
@pytest.mark.parametrize("q_len", [1, 15, 32])
@pytest.mark.parametrize("k_len", [32, 63])
@pytest.mark.parametrize("hidden_size", [32])
@pytest.mark.parametrize("head_dim", [32])
@pytest.mark.parametrize("tile_size", [(8, 16)])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("use_input_pos", [True, False])
def test_flash_attention(batch_size, q_heads, k_heads,
q_len, k_len,
hidden_size, tile_size, is_causal,
head_dim, tile_size, is_causal,
use_input_pos,
float_dtype):
query_group_size = q_heads // k_heads
TILE_M, TILE_N = tile_size
qk_scale = 1 / math.sqrt(hidden_size)
q = torch.randn((batch_size, q_heads, q_len, hidden_size), dtype=float_dtype, device='cuda')
k = torch.randn((batch_size, k_heads, k_len, hidden_size), dtype=float_dtype, device='cuda')
v = torch.randn((batch_size, k_heads, k_len, hidden_size), dtype=float_dtype, device='cuda')
qk_scale = 1 / math.sqrt(head_dim)
q = torch.randn((batch_size, q_heads, q_len, head_dim), dtype=float_dtype, device='cuda')
k = torch.randn((batch_size, k_heads, k_len, head_dim), dtype=float_dtype, device='cuda')
v = torch.randn((batch_size, k_heads, k_len, head_dim), dtype=float_dtype, device='cuda')
o = torch.zeros_like(q)
grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1)
if use_input_pos:
@@ -43,7 +43,7 @@ def test_flash_attention(batch_size, q_heads, k_heads,
(q, k, v, o,
qk_scale,
input_pos,
hidden_size, q_heads,
head_dim, head_dim, q_heads,
TILE_M, TILE_N,
query_group_size, is_causal, EVEN_K))
if is_causal: