Implement Unified Sparse Mask Strategy with Block-Level Skipping #164
This PR implements a comprehensive unified sparse mask strategy with block-level skipping for Flash Dynamic Mask Attention, addressing the need for memory-efficient sparse attention patterns beyond simple causal or fixed windows.
Key Features
🎯 Unified Mask Abstraction
Introduces a single interface supporting multiple sparse attention patterns:
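As a rough illustration only (the `SparseMaskDescriptor` name, its fields, and `is_block_active` below are hypothetical stand-ins, not the types added by this PR), the abstraction can be pictured as one descriptor that answers a single question for every (query-block, key-block) tile: does this tile contain any unmasked element?

```python
# Illustrative sketch only -- names and fields are hypothetical, not this PR's CUDA types.
from dataclasses import dataclass
from typing import Optional, Set, Tuple

@dataclass
class SparseMaskDescriptor:
    """One descriptor that can express several sparse attention patterns."""
    mask_type: str                              # e.g. "causal", "sliding_window", "block_sparse"
    block_size: int = 128                       # tile granularity used by the kernel
    window: Optional[int] = None                # key window for "sliding_window"
    block_map: Optional[Set[Tuple[int, int]]] = None  # active (q_block, k_block) pairs

    def is_block_active(self, q_block: int, k_block: int) -> bool:
        """True if the (q_block, k_block) tile contains at least one unmasked element."""
        b = self.block_size
        q_lo, q_hi = q_block * b, (q_block + 1) * b - 1
        k_lo, k_hi = k_block * b, (k_block + 1) * b - 1
        if self.mask_type == "causal":
            return k_lo <= q_hi                  # some key position <= some query position
        if self.mask_type == "sliding_window":
            return k_lo <= q_hi and k_hi >= q_lo - (self.window or 0)
        if self.mask_type == "block_sparse":
            return (q_block, k_block) in (self.block_map or set())
        return True                              # dense fallback: never skip
```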
⚡ Block-Level Computation Skipping
Implements unified OR-reduction based skip logic that operates at tile granularity:
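A minimal PyTorch sketch of the idea (host-side only; the PR's actual skip logic lives in the CUDA kernel): a boolean mask is OR-reduced per (query-block, key-block) tile, and tiles whose reduction is False never load K/V or compute QK^T.

```python
# Host-side sketch of OR-reduction block skipping; not the CUDA kernel itself.
import torch

def block_activity_map(mask: torch.Tensor, block_m: int = 128, block_n: int = 128) -> torch.Tensor:
    """OR-reduce a [seq_q, seq_k] boolean mask to an [n_q_blocks, n_k_blocks] activity map.

    A tile is active (True) if any element inside it is unmasked; inactive tiles
    can be skipped entirely -- no K/V loads and no QK^T for that tile.
    """
    seq_q, seq_k = mask.shape
    n_q, n_k = seq_q // block_m, seq_k // block_n
    tiles = mask[: n_q * block_m, : n_k * block_n].reshape(n_q, block_m, n_k, block_n)
    return tiles.any(dim=3).any(dim=1)   # logical OR over every tile

# Example: a causal mask over 1024 tokens keeps only the lower-triangular tiles.
seq = 1024
causal = torch.tril(torch.ones(seq, seq)).bool()
active = block_activity_map(causal)
print(active.float().mean().item())  # fraction of tiles actually computed (0.5625 here)
```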
💾 Memory Efficiency
Dramatic memory savings for long sequences:
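As a back-of-the-envelope sketch only (assuming one byte per element for a dense boolean mask and one byte per tile flag at a hypothetical 128-token block size), the scaling difference looks like this:

```python
# Rough scaling comparison under stated assumptions; not measured numbers from this PR.
def dense_mask_bytes(seq_len: int) -> int:
    return seq_len * seq_len                      # [seq, seq] boolean mask, 1 byte/element

def block_descriptor_bytes(seq_len: int, block: int = 128) -> int:
    n_blocks = (seq_len + block - 1) // block
    return n_blocks * n_blocks                    # 1 byte per (q_block, k_block) flag

for seq in (32_768, 131_072):
    dense, blocked = dense_mask_bytes(seq), block_descriptor_bytes(seq)
    print(f"seq={seq:>7}: dense ~{dense / 2**30:.0f} GiB, "
          f"block-level ~{blocked / 2**10:.0f} KiB ({dense // blocked}x smaller)")
```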
Implementation Details
Core CUDA Components
- `UnifiedSparseMask` class with lightweight block descriptors
- `MaskFactory` utilities for creating different mask types
- `Mask` struct for backward compatibility
- `apply_mask_with_skip_check()`
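For orientation only, here is a host-side Python approximation of how such factory helpers might hang together, building on the hypothetical `SparseMaskDescriptor` sketch above (the real `MaskFactory` is CUDA/C++ and its signatures are in the diff):

```python
# Hypothetical factory helpers mirroring the MaskFactory idea; real code is CUDA/C++.
def make_causal_mask(block_size: int = 128) -> SparseMaskDescriptor:
    return SparseMaskDescriptor(mask_type="causal", block_size=block_size)

def make_sliding_window_mask(window: int, block_size: int = 128) -> SparseMaskDescriptor:
    return SparseMaskDescriptor(mask_type="sliding_window", window=window, block_size=block_size)

def make_block_sparse_mask(active_blocks, block_size: int = 128) -> SparseMaskDescriptor:
    return SparseMaskDescriptor(mask_type="block_sparse", block_map=set(active_blocks), block_size=block_size)
```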
Python API
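The only Python-level detail stated in this description is that `sparse_mask` is optional everywhere. A hedged sketch of what such an entry point could look like follows; the function names are stand-ins, PyTorch's built-in attention substitutes for the CUDA kernel, and the dense expansion reuses the `SparseMaskDescriptor` sketch above (it is tile-granular, so only an approximation of the exact pattern):

```python
# Hypothetical Python surface; only the optional sparse_mask argument comes from this PR.
import torch
import torch.nn.functional as F

def descriptor_to_dense(desc, seq_q: int, seq_k: int) -> torch.Tensor:
    """Expand a block-level descriptor into a dense [seq_q, seq_k] boolean mask."""
    mask = torch.zeros(seq_q, seq_k, dtype=torch.bool)
    b = desc.block_size
    for qb in range((seq_q + b - 1) // b):
        for kb in range((seq_k + b - 1) // b):
            if desc.is_block_active(qb, kb):
                mask[qb * b:(qb + 1) * b, kb * b:(kb + 1) * b] = True
    return mask

def sparse_flash_attention(q, k, v, sparse_mask=None):
    """Stand-in entry point: sparse_mask is optional, matching the PR description.
    With sparse_mask=None this is the ordinary dense path; with a descriptor, the
    real kernel would skip inactive tiles instead of materialising the mask."""
    if sparse_mask is None:
        return F.scaled_dot_product_attention(q, k, v)
    dense = descriptor_to_dense(sparse_mask, q.shape[-2], k.shape[-2])
    return F.scaled_dot_product_attention(q, k, v, attn_mask=dense)
```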
Performance Benefits
Usage Example
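A toy end-to-end run, reusing the hypothetical sketches from the sections above (descriptor, factory helpers, stand-in entry point); sizes are kept small for clarity:

```python
# Toy walkthrough with the hypothetical sketches above in scope; not the PR's real API.
import torch

q = torch.randn(1, 8, 1024, 64)   # [batch, heads, seq_len, head_dim]
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# Sliding-window pattern: each query attends to at most the previous 256 keys.
mask = make_sliding_window_mask(window=256, block_size=128)

# Which of the 8x8 tiles would the kernel actually compute?
active = [[mask.is_block_active(qb, kb) for kb in range(8)] for qb in range(8)]
print(sum(map(sum, active)), "of 64 tiles active")              # 21 of 64 for these sizes

out_dense  = sparse_flash_attention(q, k, v)                    # existing dense path
out_sparse = sparse_flash_attention(q, k, v, sparse_mask=mask)  # block-skipping path
print(out_sparse.shape)                                         # torch.Size([1, 8, 1024, 64])
```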
Backward Compatibility
The implementation maintains full backward compatibility:
- `sparse_mask` parameter is optional in all APIs

Testing & Validation
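One natural validation, sketched here with PyTorch only (not the PR's actual tests), is that a masked fast path agrees with a dense reference that applies the same boolean mask element-wise:

```python
# Hypothetical correctness check: fast masked path vs. dense masked reference.
import torch
import torch.nn.functional as F

def dense_reference(q, k, v, bool_mask):
    """Plain softmax attention with an explicit boolean mask (True = attend)."""
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~bool_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = torch.randn(1, 2, 256, 64)
k = torch.randn(1, 2, 256, 64)
v = torch.randn(1, 2, 256, 64)
causal = torch.tril(torch.ones(256, 256)).bool()

ref = dense_reference(q, k, v, causal)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # stand-in for the sparse kernel
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```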
This implementation enables efficient processing of extremely long sequences (32K-128K tokens) with sparse attention patterns while maintaining the memory efficiency and performance characteristics of Flash Attention.
Fixes #163.