Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[CUDA] Output of SDPA should have same layout with inputs#2826

Merged
zcbenz merged 1 commit intoml-explore:mainfrom
zcbenz:cuda-sdpa-layout
Nov 25, 2025
Merged

[CUDA] Output of SDPA should have same layout with inputs#2826
zcbenz merged 1 commit intoml-explore:mainfrom
zcbenz:cuda-sdpa-layout

Conversation

@zcbenz
Copy link
Collaborator

@zcbenz zcbenz commented Nov 24, 2025

For following code:

import mlx.core as mx

B, T, N, D = (1, 128, 32, 64) 
scale = D ** -0.5

q = mx.random.normal((B, T, N, D), dtype=mx.float16).transpose(0, 2, 1, 3)
k = mx.random.normal((B, T, N, D), dtype=mx.float16).transpose(0, 2, 1, 3)
v = mx.random.normal((B, T, N, D), dtype=mx.float16).transpose(0, 2, 1, 3)

o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale).transpose(0, 2, 1, 3)
check_row_contiguous = mx.flatten(o)
mx.eval(check_row_contiguous)

Before the change the graph is:

before

After the change the graph becomes:

after

i.e. the o becomes row contiguous.

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to merge this. I left some questions cause I am still not sure..but I think what you have is better than what we had before so onward.

@zcbenz zcbenz merged commit f8bd675 into ml-explore:main Nov 25, 2025
10 checks passed
@zcbenz zcbenz deleted the cuda-sdpa-layout branch November 25, 2025 06:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants