
Conversation

@avik-pal (Member) commented on Aug 30, 2025

@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from 646aa5e to bf2b3e6 on August 30, 2025 05:54
@avik-pal avik-pal mentioned this pull request on Aug 30, 2025
@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from bf2b3e6 to 7211f5f on August 30, 2025 05:56
github-actions bot (Contributor) commented on Aug 30, 2025

Benchmark Results (ASV)

Benchmark        main               c976d32...         main / c976d32...
basics/overhead  0.127 ± 0.003 μs   0.131 ± 0.0041 μs  0.97 ± 0.038
time_to_load     0.877 ± 0.0096 s   0.865 ± 0.0019 s   1.01 ± 0.011

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions" -> "Benchmark a pull request" -> [the most recent run] -> "Artifacts" (at the bottom).

@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch 3 times, most recently from 1c64d78 to 7280a95 on August 31, 2025 01:31
@avik-pal (Member, Author) commented:
Comparing against NNop.flash_attention

using NNop, CUDA, LinearAlgebra
using Reactant, Lux, LuxLib

E, L, H, B = 64, 4096, 4, 4
causal = true

q = randn(Float32, E, L, H, B);
k = randn(Float32, E, L, H, B);
v = randn(Float32, E, L, H, B);

q_cu = cu(q);
k_cu = cu(k);
v_cu = cu(v);

q_ra = Reactant.to_rarray(q);
k_ra = Reactant.to_rarray(k);
v_ra = Reactant.to_rarray(v);

# Wrap the NNop call with a device synchronize so `@time` measures the
# full kernel runtime rather than just the asynchronous launch.
function flash_attention(q, k, v; causal)
    res = NNop.flash_attention(q, k, v; causal)
    CUDA.synchronize()
    return res
end

res = flash_attention(q_cu, k_cu, v_cu; causal);

@time flash_attention(q_cu, k_cu, v_cu; causal); # 0.032220 seconds (130 allocations: 4.875 KiB)

# LuxLib SDPA on the same (head_dim, seq_len, heads, batch) layout;
# `head_dim=1, token_dim=2` tell it which array dimensions are which.
function sdpa_fn(q, k, v; causal)
    mask = causal ? make_causal_mask(q) : nothing
    return scaled_dot_product_attention(
        q, k, v; head_dim=1, token_dim=2, mask=mask, bias=nothing
    )
end

@code_hlo sdpa_fn(q_ra, k_ra, v_ra; causal)

compiled_fn = @compile sync = true sdpa_fn(q_ra, k_ra, v_ra; causal)

@time compiled_fn(q_ra, k_ra, v_ra); # 0.025238 seconds (19 allocations: 592 bytes)
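For sanity-checking that both paths compute the same attention, a plain-Julia CPU reference can be handy. The helper below is an illustrative sketch, not part of this PR: the name `sdpa_reference` is mine, and it assumes the `(head_dim, seq_len, heads, batch)` layout used in the benchmark above, with softmax taken over the key dimension.

```julia
# Minimal CPU reference for scaled dot-product attention with an optional
# causal mask (hypothetical sanity-check helper, not part of this PR).
# Layout matches the benchmark above: (E, L, H, B) = (head_dim, seq_len, heads, batch).
function sdpa_reference(q, k, v; causal::Bool=false)
    E, L, H, B = size(q)
    scale = inv(sqrt(Float32(E)))
    out = similar(q)
    for b in 1:B, h in 1:H
        qh = @view q[:, :, h, b]            # E × L
        kh = @view k[:, :, h, b]
        vh = @view v[:, :, h, b]
        scores = (kh' * qh) .* scale        # L × L; scores[j, i] = scale * ⟨k_j, q_i⟩
        if causal
            # query i may only attend to keys j ≤ i
            for i in 1:L, j in (i + 1):L
                scores[j, i] = -Inf32
            end
        end
        # softmax over the key dimension (dim 1), numerically stabilized
        scores .= exp.(scores .- maximum(scores; dims=1))
        scores ./= sum(scores; dims=1)
        out[:, :, h, b] = vh * scores       # column i = Σ_j scores[j, i] * v_j
    end
    return out
end
```

With such a reference, the CUDA and Reactant outputs can each be copied back to the host and compared against it with `isapprox` at a loose tolerance.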

@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch 3 times, most recently from 88c82a1 to 80da2e3 on September 1, 2025 23:24
@avik-pal avik-pal changed the base branch from main to sdpa-luxlib-test1 September 1, 2025 23:25
@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from 80da2e3 to 14cc8c7 on September 2, 2025 15:02
@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from f4b5605 to 4d2df40 on September 2, 2025 17:47
Base automatically changed from sdpa-luxlib-test1 to main September 2, 2025 18:35
@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from 4d2df40 to d0a27ca on September 2, 2025 18:59
@avik-pal avik-pal marked this pull request as ready for review September 2, 2025 18:59
feat: update Lux and Qwen example to use faster sdpa

fix: qwen3 model??

fix: transpose ordering

feat: expose kwargs in bmm + documentation

fix: repeat batch dim

feat: eliminate transpose in sdpa

feat: support bias in Reactant sdpa

fix: reduce permutations

fix: inv scale before

chore: cleanup

fix: zygote gradient for sdpa and batched_matmul

chore: add some comments

fix: throw a dimension mismatch error

feat: allow a pass through

feat: support gqa correctly

feat: use GQA directly in Qwen3 model

[Filtered to include only lib/LuxLib changes]
Original commit: 88c82a1

fix: correct error messages

feat: make is_causal a keyword argument

test: add tests for generalized bmm

fix: restrict LV uses with PermutedDimsArray

docs: add attention documentation

test: attention test

fix: reshape

test: attention test

test: attention test

fix: import

test: fix sdpa

feat: faster scaled_dot_product_attention for reactant

feat: add is_causal to MHA layer

feat: use implicit causal masking
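The GQA-related commits above ("fix: repeat batch dim", "feat: support gqa correctly") concern grouped-query attention, where several query heads share each key/value head. One common way to feed GQA tensors into a full-attention kernel is to repeat the KV heads; the sketch below is illustrative only (the helper name and the `(head_dim, seq_len, heads, batch)` layout are assumptions, not LuxLib's actual implementation).

```julia
# Illustrative GQA helper: expand H_kv key/value heads to n_heads query heads
# by repeating each KV head n_heads ÷ H_kv times along the head dimension.
# (Hypothetical sketch; LuxLib handles GQA internally and may avoid the copy.)
function repeat_kv_heads(kv::AbstractArray{T,4}, n_heads::Int) where {T}
    E, L, H_kv, B = size(kv)
    n_heads % H_kv == 0 || throw(DimensionMismatch(
        "query heads ($n_heads) must be a multiple of kv heads ($H_kv)"))
    rep = n_heads ÷ H_kv
    # insert a repetition axis, tile it, then fold it back into the head axis
    kv5 = reshape(kv, E, L, 1, H_kv, B)
    kv5 = repeat(kv5; outer=(1, 1, rep, 1, 1))
    return reshape(kv5, E, L, n_heads, B)
end
```

After repetition, heads `1:rep` are copies of KV head 1, heads `rep+1:2rep` copies of KV head 2, and so on, so a standard multi-head SDPA can be applied unchanged.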
@avik-pal avik-pal force-pushed the ap/sdpa_improvements branch from d0a27ca to c976d32 on September 2, 2025 18:59
@avik-pal avik-pal merged commit 04daba4 into main Sep 2, 2025
38 of 42 checks passed
@avik-pal avik-pal deleted the ap/sdpa_improvements branch September 2, 2025 19:51