Fix INF issue in bf16 backward pass with safer value clamping #181
+333
−3
Problem
When using the bf16 data type with large sequence lengths (e.g., seq_len=4096, window=2048), the backward pass would fail with INF errors during the first training step.

The issue occurred at line 91 in `modeling_flash_dynamic_mask_attention_utils.py`, where the extreme negative value used for attention masking (`torch.finfo(bfloat16).min` = -3.39e+38) caused numerical overflow during the fp32→bf16 conversion in the CUDA backward kernel.

Root Cause

- `torch.finfo(dtype).min` for bf16 produces values at the very edge of the representable range.
- In `convert_type<Element>(acc_dp)`, extreme intermediate computations could exceed bf16's representable range.
- `cutlass::NumericArrayConverter` could produce INF values when handling these edge-case floating point values.
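To see why a sentinel at the edge of the range is fragile, here is a small PyTorch illustration of the failure mode (a sketch of the hazard, not the kernel code):

```python
import torch

# Sketch: the bf16 masking sentinel sits at the very edge of the representable
# range, so any extra negative contribution accumulated in fp32 rounds to -inf
# when the result is cast back to bf16.
sentinel = torch.tensor(torch.finfo(torch.bfloat16).min, dtype=torch.float32)  # ≈ -3.39e+38

print(sentinel.to(torch.bfloat16))             # finite: the sentinel itself still fits in bf16
print((sentinel - 1e36).to(torch.bfloat16))    # -inf: a ~0.3% fp32 drift already overflows bf16
print(torch.tensor(-1e30).to(torch.bfloat16))  # finite: the safer sentinel leaves ample headroom
```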
Solution

This PR implements a two-level fix:
1. CUDA Kernel Level (`utils.h`, `flash_bwd_kernel.h`): adds a `convert_type_safe()` function with proper value clamping before conversion (see the sketch after this list).
2. Python Interface Level (`modeling_flash_dynamic_mask_attention_utils.py`): switches the masking value to `-1e30` (instead of `-3.39e+38`, bfloat16's minimum) and to `-1e4` (instead of `-65504`, float16's minimum).
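The actual kernel-level change lives in CUDA (`convert_type_safe()` in `utils.h`); the PyTorch sketch below only illustrates the clamp-before-convert idea and the safer sentinels, using hypothetical helper names rather than the PR's exact code:

```python
import torch

def convert_type_safe(acc: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Clamp an fp32 accumulator into dtype's finite range before down-casting
    (the same idea the CUDA convert_type_safe() applies before cutlass conversion)."""
    info = torch.finfo(dtype)
    return acc.clamp(min=info.min, max=info.max).to(dtype)

def mask_value(dtype: torch.dtype) -> float:
    """Masking sentinels kept well inside the finite range (hypothetical helper):
    -1e30 replaces bfloat16's -3.39e+38, -1e4 replaces float16's -65504."""
    return -1e30 if dtype == torch.bfloat16 else -1e4

# An extreme fp32 intermediate: finite in fp32 but outside bf16's range.
acc_dp = torch.full((4, 8), -3.4e38)
print(acc_dp.to(torch.bfloat16).isinf().any())                   # True: plain conversion overflows
print(convert_type_safe(acc_dp, torch.bfloat16).isinf().any())   # False: clamped first, stays finite

# The safer sentinel still behaves like -inf under exponentiation.
print(torch.tensor(mask_value(torch.bfloat16)).exp() == 0)       # True: exp(-1e30) underflows to 0
```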
Testing & Validation

- Added `validate_bf16_fix.py` for users to verify the fix.

Compatibility
Users experiencing the bf16 INF issue can validate the fix by running the included `validate_bf16_fix.py` script.
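For reference, a minimal standalone check in the same spirit (a sketch only; this is not the contents of `validate_bf16_fix.py`, and `scaled_dot_product_attention` merely stands in for the flash-dmattn attention call):

```python
import torch
import torch.nn.functional as F

# Sketch: run a bf16 forward/backward with a windowed additive mask that uses the
# safer -1e30 sentinel, then check that no INF/NaN shows up in the gradients.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, D = 1, 2, 64
S = 4096 if device == "cuda" else 512   # keep the CPU fallback light
window = S // 2

q, k, v = (torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

# Additive bias: 0 inside the causal sliding window, -1e30 outside it.
idx = torch.arange(S, device=device)
allowed = (idx[:, None] >= idx[None, :]) & (idx[:, None] - idx[None, :] < window)
bias = torch.where(allowed, 0.0, -1e30).to(torch.bfloat16)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
out.float().sum().backward()

for name, t in (("q", q), ("k", k), ("v", v)):
    assert torch.isfinite(t.grad).all(), f"non-finite gradient in d{name}"
print("bf16 backward produced finite gradients")
```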
Fixes issue described in #XXX where training with bf16 + large sequence lengths would fail with INF gradients.
Original prompt
This section details the original issue you should resolve
<issue_title>[BUG REPORT] INF occurs in backward phrase of the first training step</issue_title>
<issue_description>Describe the bug
Using code version v1.1.9, the training configuration is set to seq_len=4096 and window=2048. During the first training step, the backward pass consistently fails with an INF (infinity) error. Preliminary investigation has pinpointed the location in the code where the error occurs: https://github.com/SmallDoges/flash-dmattn/blob/main/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py#L91. The root cause appears to be that some numerical values are outside the representable range of the bf16 data type.
To Reproduce
Steps to reproduce the behavior:
# Paste your code here
Expected behavior
No INF values.
Environment Information
PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: NVIDIA A800-SXM4-80GB
Additional context
Error traceback
Debugging Information
Comments on the Issue (you are @copilot in this section)
@LoserCheems: The following is an analysis done using GPT-5-CodeX regarding the safety of writing bias gradients in the kernel. Next, I will thoroughly check whether the current codebase guarantees safety under all circumstances. Once confirmed, we should analyze whether the caller passes unsafe parameters to the kernel, which might cause NaN issues.
1. Definitions and Goal
We consider one backward iteration over a (query block, key block).
Let raw_ij denote the attention score for query i and key j before masking (the scaled q·k product plus bias_ij when a bias is present).
Mask and causal logic zero out (logically remove) invalid positions by setting those score entries to -INF (or a very large negative sentinel) before the softmax. In the backward kernel we reconstruct local softmax probabilities P for the block and compute dS = ∂L/∂raw_ij (which also equals ∂L/∂bias_ij when bias exists). We need to show: at the point labeled “Write dS to dBias” there is no -inf in dS, so writing it directly to global memory is safe.
2. Where -inf Appears (and Disappears)
- Masking (`apply_mask`): masked or causal‑disallowed positions in `scores` are set to -INF (or a large negative). This is still in fp32 accumulator space (the `scores` tensor is built from `acc_s`).
- Exponentiation (`scale_apply_exp2</*scale_max=*/false>`): immediately after masking, the code exponentiates those scores into probabilities; exp(-inf) = 0, so all masked / OOB positions become exactly 0 probability.
- From now on `scores` holds probabilities P (not raw logits), all finite and non‑negative; masked entries are 0, never -inf.
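A small PyTorch illustration of this behaviour (a sketch using the ordinary softmax path, not the kernel's exp2 implementation):

```python
import torch

# Sketch: positions masked to -inf come out of the exponentiation as exactly 0.
scores = torch.tensor([[1.2, -0.7, 0.3, 0.9]], dtype=torch.float32)  # fp32 accumulator space
keep = torch.tensor([[True, True, False, False]])                    # False = masked / OOB

masked_scores = scores.masked_fill(~keep, float("-inf"))
probs = torch.softmax(masked_scores, dim=-1)   # exp(-inf) = 0 for masked entries

print(probs)                       # tensor([[0.8699, 0.1301, 0.0000, 0.0000]])
assert (probs[~keep] == 0).all()   # masked entries are exactly 0, never -inf
```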
3. Construction of dS in the Kernel

Relevant excerpt (simplified):
Interpretation:

- Before the loop, dS(mi, ni) temporarily names dp (an element of ∂L/∂P).
- After the loop, each dS(mi, ni) becomes P(mi, ni) * (dp(mi, ni) - dP_sum(mi)) (times the optional dtanh factor).
- For masked entries: P(mi, ni) = 0, so dS(mi, ni) = 0 regardless of the other finite terms.
- All operands (P, dp, dP_sum, optional dtanh) are finite fp32 numbers under normal (non-overflow) regimes; no path produces -inf.
Even in edge cases:
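For instance, masked entries stay exactly zero even when dp carries extreme (but finite) values, as this small PyTorch check of the reconstruction above illustrates (a sketch, not kernel code):

```python
import torch

# P holds the reconstructed probabilities: masked entries are exactly 0.
P = torch.tensor([[0.0, 0.3, 0.7],
                  [0.0, 0.0, 1.0]], dtype=torch.float32)
# dp = dL/dP may carry extreme, but finite, fp32 values at masked positions.
dp = torch.tensor([[-1e30, 0.5, -0.2],
                   [ 1e30, -1e30, 0.1]], dtype=torch.float32)
dP_sum = (P * dp).sum(dim=-1, keepdim=True)      # row-wise dP_sum

dS = P * (dp - dP_sum)                           # the per-element update described above
assert torch.isfinite(dS).all()                  # no -inf/NaN anywhere
assert (dS[P == 0] == 0).all()                   # masked entries contribute exact zeros to dBias
# (If dp were ±inf, 0 * inf would give NaN — hence the finiteness requirement above.)
```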