
Conversation


@Copilot Copilot AI commented Sep 10, 2025

This PR implements optional attn_mask and attn_bias inputs with adaptive computation skipping to improve performance and reduce unnecessary memory operations in Flash Dynamic Mask Attention.

Problem

The current implementation always assumes both attn_mask and attn_bias are active, causing:

  • Unnecessary global memory loads when only one tensor is needed
  • Needless dbias computation when no bias is conceptually required
  • Inefficient workarounds like fake all-ones masks or zero bias tensors

Solution

Added support for four explicit modes with conditional processing (a reference sketch of their semantics follows the table):

Case | attn_mask | attn_bias | Behavior
-----|-----------|-----------|---------
A    | None      | None      | Dense path, no block skip, no bias load/add (fastest)
B    | Tensor    | None      | Block skip using mask, no bias add/dbias
C    | None      | Tensor    | No block skip (all blocks active), add bias + compute dbias
D    | Tensor    | Tensor    | Current behavior (mask skip + bias add + dbias)
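
To make the table concrete, here is a minimal pure-PyTorch sketch of what each case computes. It reproduces the math only, none of the kernel's block skipping or memory savings, and the boolean-mask / additive-bias conventions are assumptions for illustration rather than taken from the kernel documentation.

import math
import torch
from typing import Optional

def reference_attention(q, k, v,
                        attn_mask: Optional[torch.Tensor] = None,
                        attn_bias: Optional[torch.Tensor] = None):
    # Pure-PyTorch reference of the four cases' semantics.
    # Assumed shapes: q, k, v are (batch, heads, seqlen, head_dim);
    # attn_mask is boolean and attn_bias is additive, both broadcastable
    # to the (seqlen_q, seqlen_k) score matrix.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if attn_bias is not None:                        # Cases C and D
        scores = scores + attn_bias
    if attn_mask is not None:                        # Cases B and D
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)
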

Key Changes

Python Interface

  • Both attn_mask and attn_bias parameters now accept Optional[Tensor] = None
  • Added use_mask and use_bias flags that are passed to the CUDA kernels
  • Conditional gradient computation: dbias is returned only when a bias is provided (see the sketch below)
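
A hedged sketch of the Python-level contract: the flags are derived from whether the optional tensors were passed, and a dbias only exists when a bias was provided. It reuses reference_attention from the sketch above; derive_flags is an illustrative helper, not the actual binding.

import torch

def derive_flags(attn_mask, attn_bias):
    # (use_mask, use_bias) derived exactly as described in the bullets above.
    return attn_mask is not None, attn_bias is not None

# Conditional gradient: dbias exists only when a bias tensor was provided.
q = torch.randn(1, 2, 8, 16, requires_grad=True)
k = torch.randn(1, 2, 8, 16, requires_grad=True)
v = torch.randn(1, 2, 8, 16, requires_grad=True)
bias = torch.zeros(1, 2, 8, 8, requires_grad=True)

reference_attention(q, k, v).sum().backward()                   # Case A: no dbias at all
reference_attention(q, k, v, attn_bias=bias).sum().backward()   # Case C: dbias produced
print(derive_flags(None, None), derive_flags(None, bias))       # (False, False) (False, True)
print(bias.grad is not None)                                    # True
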

CUDA Kernels

  • Modified the mask/bias application logic to conditionally process each tensor based on the runtime flags
  • Updated the backward pass to skip dbias computation when use_bias=False
  • All changes preserve performance when both tensors are provided; the block-skip idea behind the mask path is sketched below
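
The "block skip" that the mask enables can be pictured with a small pure-Python sketch: a key tile whose mask slice is entirely inactive contributes nothing to the output, so its loads and math can be skipped. The tile size, layout, and helper name here are illustrative only, not the kernel's actual tiling.

import torch

def count_active_key_tiles(attn_mask: torch.Tensor, tile: int = 128) -> int:
    # attn_mask: boolean, shape (..., seqlen_q, seqlen_k). A tile with no
    # active positions contributes nothing, so the kernel can skip it
    # entirely (Cases B and D).
    active = 0
    for start in range(0, attn_mask.size(-1), tile):
        if attn_mask[..., start:start + tile].any():
            active += 1
    return active

mask = torch.zeros(1, 1, 256, 1024, dtype=torch.bool)
mask[..., :256] = True                      # only the first 256 keys are active
print(count_active_key_tiles(mask), "of", 1024 // 128, "key tiles processed")
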

Usage Example

from flash_dmattn import flash_dmattn_func_auto

flash_attn = flash_dmattn_func_auto()

# Case A: Dense attention (fastest for dense workloads)
out = flash_attn(q, k, v, attn_mask=None, attn_bias=None)

# Case B: Sparse attention with mask only  
out = flash_attn(q, k, v, attn_mask=sparse_mask, attn_bias=None)

# Case C: Dense attention with bias (e.g., ALiBi, relative position)
out = flash_attn(q, k, v, attn_mask=None, attn_bias=position_bias)

# Case D: Full functionality (unchanged)
out = flash_attn(q, k, v, attn_mask=sparse_mask, attn_bias=position_bias)

Performance Benefits

  • Case A: Eliminates both memory streams + skip logic overhead
  • Case B: Removes bias path (saves reads/writes + math)
  • Case C: Removes mask loads/OR reductions (simpler control flow)
  • Case D: Baseline performance, unchanged (a quick timing check is sketched below)
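
Actual savings depend on shapes, sparsity, and hardware. A rough way to check on your own setup, reusing flash_attn, q, k, v, sparse_mask, and position_bias from the usage example above (this timing harness is a sketch, not part of the library):

import time
import torch

def bench(fn, iters=50, warmup=10):
    # Assumes CUDA tensors; synchronize so we time the kernels, not the launch.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

t_a = bench(lambda: flash_attn(q, k, v, attn_mask=None, attn_bias=None))
t_d = bench(lambda: flash_attn(q, k, v, attn_mask=sparse_mask, attn_bias=position_bias))
print(f"Case A: {t_a * 1e3:.2f} ms   Case D: {t_d * 1e3:.2f} ms")
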

Backward Compatibility

The implementation is fully backward compatible: existing code that passes both attn_mask and attn_bias continues to work unchanged (Case D), and the new None defaults only take effect when an argument is omitted.

Fixes #161.



@Copilot Copilot AI changed the title from "[WIP] [FEATURE REQUEST] Optional mask & bias inputs with adaptive computation skipping" to "Add optional mask & bias inputs with adaptive computation skipping" on Sep 10, 2025
Copilot finished work on behalf of LoserCheems September 10, 2025 11:52
@Copilot Copilot AI requested a review from LoserCheems September 10, 2025 11:52