Use lower-right causal mask alignment consistently#2967

Merged
awni merged 10 commits into ml-explore:main from Anri-Lombard:fix/sdpa-causal-mask-offset
Jan 29, 2026

Conversation

@Anri-Lombard
Contributor

Summary

  • Document that MLX's mask="causal" uses lower-right alignment
  • Clarify the difference from PyTorch's default is_causal=True (upper-left)

When T_q != T_kv, this distinction matters (see the sketch after this list):

  • MLX (lower-right): Last query aligns with last key
  • PyTorch default (upper-left): First query aligns with first key
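To make the alignment concrete, here is a small NumPy sketch (illustrative only, not the MLX implementation) that builds both masks; the offset q_off = max(0, T_kv - T_q) is what shifts the diagonal into the lower-right corner:

import numpy as np

def causal_mask(T_q, T_kv, lower_right=True):
    # lower-right shifts the allowed diagonal so the last query sees every key
    q_off = max(0, T_kv - T_q) if lower_right else 0
    q_idx = np.arange(T_q)[:, None] + q_off
    k_idx = np.arange(T_kv)[None, :]
    return k_idx <= q_idx  # True where the query may attend to the key

print(causal_mask(2, 4, lower_right=True).astype(int))   # [[1 1 1 0] [1 1 1 1]]
print(causal_mask(2, 4, lower_right=False).astype(int))  # [[1 0 0 0] [1 1 0 0]]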

References:

Relates to #2835

Clarify that MLX uses lower-right alignment for causal masks when
T_q != T_kv, which differs from PyTorch's default upper-left alignment.

Relates to ml-explore#2835
Collaborator

@zcbenz left a comment

I don't think PyTorch has a causal_lower_right option for SDPA and the description is not really right.

@Anri-Lombard
Contributor Author

Hey @zcbenz, it does have causal_lower_right since 2.3 and can be used with SDPA via the attn_mask parameter. I ran a script with:

import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bias = causal_lower_right(T_q, T_kv)  # lower-right-aligned causal bias
F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

to verify.

Here is the tutorial that documents this explicitly: https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html.

I also verified masks are mathematically identical. For example with T_q=2, T_kv=4:

  MLX's mask (using q_off = max(0, kL - qL)):
        k0  k1  k2  k3
  q0 [  1   1   1   0  ]
  q1 [  1   1   1   1  ]

  PyTorch's causal_lower_right(2, 4):
        k0  k1  k2  k3
  q0 [  1   1   1   0  ]
  q1 [  1   1   1   1  ]

  PyTorch's is_causal=True (upper_left):
        k0  k1  k2  k3
  q0 [  1   0   0   0  ]
  q1 [  1   1   0   0  ]
  

The first two are identical; the third is different. This is also consistent with MLX's CUDA backend which uses cuDNN's set_causal_mask_bottom_right.

Is there something specific about the description you think is incorrect? If your concern is that causal_lower_right isn't a direct SDPA parameter (like is_causal=True) but rather a separate utility class, I could clarify the wording to use the full module path torch.nn.attention.bias.causal_lower_right.

@zcbenz
Collaborator

zcbenz commented Jan 18, 2026

Thanks for linking the docs, this is new to me. On the behavior, it actually depends on whether T_q is larger or smaller than T_kv:

if (q.shape(2) > k.shape(2)) {
  options.set_causal_mask(do_causal);
} else {
  options.set_causal_mask_bottom_right(do_causal);
}

The mask uses lower-right alignment when T_q <= T_kv and upper-left when T_q > T_kv.

@Anri-Lombard
Contributor Author

Thanks! Fixed to describe the conditional alignment behavior 🙏

Collaborator

@zcbenz left a comment

Looks good to me. /cc @awni for a second look.

@awni
Member

awni commented Jan 21, 2026

The comment definitely makes sense. But I also find it a bit strange that we switch from lower-right to upper-left depending on whether the query is longer or shorter than the keys. It's quite rare for the query to be longer than the keys, which is why we never really looked at it carefully.

I'm wondering if we should change the behavior in that case rather than documenting something that is a bit unusual? Or maybe it's a good idea to keep it this way?

@zcbenz
Collaborator

zcbenz commented Jan 21, 2026

I agree the current behavior is unusual, and using lower-right for all cases would be a better choice.

@awni
Member

awni commented Jan 21, 2026

@Anri-Lombard what do you think about changing the behavior to always be lower-right even when qL > kL? Do you want to push a patch to this PR or send a new one instead?

@Anri-Lombard
Contributor Author

Hey @awni, always lower-right makes sense. The change is minimal (unless I'm missing something) - just two cuDNN locations (forward/backward) and the CPU fallback offset calculation. I'll update this PR to make the behavior change instead of just documenting it 👍
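For context, the CPU fallback piece is essentially dropping the clamp from the offset shown earlier in the thread (q_off = max(0, kL - qL)). A hedged sketch of the idea, not the actual diff:

# before: q_off = max(0, kL - qL)  -> clamps to upper-left when qL > kL
# after:  q_off = kL - qL          -> always lower-right; negative values mean
#                                     early queries have no keys to attend to
q_off = kL - qL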

@awni
Member

awni commented Jan 22, 2026

Yes, the change should be pretty straightforward. We may also need to update the mask index calculation in the Metal kernels. If you add a test for this case as well (qL > kL), that would be great. I can help with the Metal kernels if needed.

- cuDNN: Always use set_causal_mask_bottom_right() instead of conditionally
  selecting based on qL vs kL. This aligns with FlashAttention/PyTorch behavior.
- Steel kernels: Add NaN protection for sum_score == 0 edge case when all
  keys are masked.
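The NaN case comes from dividing by a zero softmax denominator when every key in a row is masked to -inf. A rough NumPy sketch of the pattern (the name sum_score follows the commit message; the real guard lives in the Steel Metal kernels):

import numpy as np

scores = np.array([-np.inf, -np.inf, -np.inf])   # every key masked out
weights = np.exp(scores)                         # all zeros
sum_score = weights.sum()                        # 0.0, so weights / sum_score would be NaN
denom = sum_score if sum_score != 0.0 else 1.0   # the guard: avoid 0 / 0
print(weights / denom)                           # zeros instead of NaN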
Enable scaled_dot_product_attention to handle cases where query sequence
is longer than key sequence with causal mask. When qL > kL, early queries
have no keys to attend to and output zeros.

Changes:
- Remove Metal routing guard that blocked qL > kL for causal mask
- Fix CPU fallback to use proper lower-right alignment (not clamped)
- Zero out output rows where queries have no keys to attend (row_pos < 0)
- Update test references to handle all-masked rows correctly
@Anri-Lombard changed the title from "Document causal mask alignment in scaled_dot_product_attention" to "Use lower-right causal mask alignment consistently" on Jan 23, 2026
@Anri-Lombard
Contributor Author

Anri-Lombard commented Jan 23, 2026

@awni and @zcbenz updated and took a stab at the Metal kernels as well - feel free to push changes directly, or, if you don't mind the extra time, point out where I deviated so I can learn the convention preferences 🙏

For qL > kL, early queries have no keys to attend to. Softmax of all-masked values gives uniform weights (exp(finite_min - finite_min) = 1), not zeros. Following PyTorch's pytorch/pytorch#108108 convention, we explicitly zero these rows... I think this is the only "big" change.
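A quick NumPy illustration of that point (a sketch; finite_min stands in for the large negative fill value used for masking):

import numpy as np

finite_min = np.finfo(np.float32).min
scores = np.full(4, finite_min, dtype=np.float32)  # every key masked with a finite fill
weights = np.exp(scores - scores.max())            # exp(0) = 1 for every entry
print(weights / weights.sum())                     # uniform [0.25 0.25 0.25 0.25], not zeros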

@Anri-Lombard
Contributor Author

@awni, you mentioned the tests: the existing test shapes (127, 65, ...) with mask="causal" already cover the qL > kL case. Would you prefer an explicit test that verifies early queries output zeros? 🙏

@awni
Member

awni commented Jan 23, 2026

Nope if it's already tested that is fine!

@awni
Member

awni commented Jan 23, 2026

I don't think we should ensure 0s in the qL > kL case. What to do when every key position for a given query is masked is a problem I've looked at in the past, and right now it's not consistent. For now let's leave it as undefined behavior and then look into a more principled fix if necessary. I also would rather not reduce performance overall to handle an edge case we don't really care much about.

Per review feedback, leave qL > kL with causal mask as undefined
behavior rather than ensuring zeros. This avoids performance overhead
for an edge case. Tests skip this undefined case.
@Anri-Lombard
Contributor Author

@awni done - removed the zero-row handling. The qL > kL + causal case is now undefined behavior as suggested. Tests skip that case.

Comment on lines +588 to +590
# Skip causal tests when qL > kL (undefined behavior)
if mask_str == "causal" and qL > kL:
    continue
Member

Rather than skipping this test, could you add a little step after the computation which checks that the parts which should match do match?

Member

So basically slice off the initial qL - kL rows from the result if that's greater than 0 and then compare.
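In test terms, something like this (a hypothetical sketch; out, ref, and atol stand for the computed result, the reference, and the tolerance; layout is assumed (B, H, qL, D)):

extra = qL - kL
if mask_str == "causal" and extra > 0:
    # drop the rows whose queries have no keys to attend to (undefined behavior)
    out = out[..., extra:, :]
    ref = ref[..., extra:, :]
self.assertLessEqual(mx.max(mx.abs(out - ref)).item(), atol)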

@Anri-Lombard
Contributor Author

Anri-Lombard commented Jan 26, 2026

I'm struggling to implement this; I'm getting a bug where the fast path outputs garbage because the fallback is not being applied when it should be. I'll need some time to figure this out 🙏

Member

You could revert this change for now which would likely fix it.

const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
    (query_sequence_length <= key_sequence_length && do_causal); // <- add that back

Member

So it won't dispatch to the fused implementation for that case... but that's OK for now. And if you want to fix it in a follow-on, that would be great!

@Anri-Lombard
Contributor Author

Anri-Lombard commented Jan 28, 2026

Thanks! All tests now pass locally, and the tests slice off undefined rows instead of skipping 🙏 Happy to do a follow-up PR for a proper fix when qL > kL 👍

Member

Great, I will merge it when the CI tests clear.

Anri-Lombard and others added 3 commits January 26, 2026 17:43
Co-authored-by: Awni Hannun <[email protected]>
Per awni's feedback, revert the Metal backend condition to require
query_sequence_length <= key_sequence_length for causal mask. This
prevents dispatching to the fused kernel for the qL > kL case.

The test now slices off the first qL-kL rows (undefined behavior region)
before comparison instead of skipping these cases entirely.
When transpose=True, output shape is (B, qL, qH, D) with sequence
dimension at index 1. The previous fix was slicing dimension 2 for
both cases, causing test failures. Now correctly slices dimension 1
for transpose=True and dimension 2 for transpose=False.
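In sketch form, with extra = qL - kL (hypothetical test code mirroring the commit description):

extra = qL - kL
if transpose:
    out = out[:, extra:, :, :]   # (B, qL, qH, D): sequence at axis 1
else:
    out = out[:, :, extra:, :]   # (B, qH, qL, D): sequence at axis 2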
@awni merged commit 0c6a895 into ml-explore:main on Jan 29, 2026
16 checks passed
@zcbenz
Collaborator

zcbenz commented Jan 30, 2026

It turns out cuDNN does not like this configuration ☹️:

======================================================================
ERROR: test_sdpa (test_fast_sdpa.TestSDPA.test_sdpa) (B=1, qsl=127, ksl=65, head_dim=64, n_q_heads=32, n_kv_heads=8, mask='causal', transpose=False, dtype='float16')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\cygwin64\home\cheng\codes\mlx\python\tests\test_fast_sdpa.py", line 637, in test_sdpa
    self.assertLessEqual(mx.max(diff).item(), atol)
                         ~~~~~~~~~~~~~~~~~^^
RuntimeError: graph.prepare() failed: Bottom right causal mask does not support max_s_q > max_s_kv. Please virtually slice the Q tensor and pass it as max_s_q == max_s_kv.

(The test does not run in CI as the hardware is not supported by cuDNN)

I'm going to disable cuDNN SDPA for T_q > T_kv with mask='causal' for now.
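For reference, the failing configuration from the test name above can be reproduced with something like this (a sketch; shapes follow B=1, n_q_heads=32, n_kv_heads=8, qsl=127, ksl=65, head_dim=64):

import mlx.core as mx

q = mx.random.normal((1, 32, 127, 64))  # qL = 127 > kL = 65
k = mx.random.normal((1, 8, 65, 64))
v = mx.random.normal((1, 8, 65, 64))
# on a cuDNN-backed CUDA build this hit the graph.prepare() error above
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5, mask="causal")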

@Anri-Lombard
Contributor Author

Dang! Sorry to see this @zcbenz 🙏 I can have a look later to see how we could make cuDNN happy, but disabling it for now makes sense

