Commit dbe421e

Fix divide-by-zero in GroupNorm two-pass kernel for large batch sizes (#1984)
When batch size N is large enough (e.g., N=512 with C=640), the heuristic `blocks_per_act_slice = 256 / params.n` truncates to 0 via integer division, causing a subsequent `div_up(params.hw, blocks_per_act_slice)` to divide by zero.

Fix by clamping blocks_per_act_slice to at least 1 in both forward and backward two-pass setup functions. Add regression test covering the exact repro case and all three heuristic branches.

Signed-off-by: Tailing Yuan <[email protected]>
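The failure mode described in the commit message can be reproduced outside the kernel. The following is a minimal Python sketch, not the apex code: `hw_per_block` is a hypothetical helper, `div_up` is assumed to be the usual ceiling division, and the `min(...)` step from the setup function is left out because it cannot raise a zero back to one.

```python
def div_up(m: int, n: int) -> int:
    """Ceiling division; assumed to match the kernel's div_up helper."""
    return (m + n - 1) // n


def hw_per_block(hw: int, n: int, budget: int = 256, clamp: bool = True) -> int:
    """Hypothetical stand-in for the setup arithmetic: budget / n, then div_up."""
    # Integer division truncates toward zero, as `256 / params.n` does in C++.
    blocks_per_act_slice = budget // n
    if clamp:
        # The fix from this commit: keep the divisor at 1 or above.
        blocks_per_act_slice = max(blocks_per_act_slice, 1)
    return div_up(hw, blocks_per_act_slice)


if __name__ == "__main__":
    # N=512 with the 256 budget from the commit message: 256 // 512 == 0.
    try:
        hw_per_block(hw=16 * 16, n=512, clamp=False)
    except ZeroDivisionError as exc:
        print("without clamp:", exc)
    print("with clamp:", hw_per_block(hw=16 * 16, n=512, clamp=True))
```

Python raises `ZeroDivisionError` here; in C++ an integer division by zero is undefined behavior, so the real symptom depends on the platform.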
1 parent 212061e commit dbe421e

3 files changed

Lines changed: 41 additions & 0 deletions

apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu

Lines changed: 3 additions & 0 deletions
@@ -204,6 +204,9 @@ void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params& params, si
     blocks_per_act_slice = 512 / params.n;
   }

+  // Clamp to at least 1 to avoid divide-by-zero when batch size is large.
+  blocks_per_act_slice = max(blocks_per_act_slice, 1);
+
   // Make sure we launch blocks per activation is no less than activations
   blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n));


apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu

Lines changed: 3 additions & 0 deletions
@@ -126,6 +126,9 @@ void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params& params, si
     blocks_per_act_slice = 512 / params.n;
   }

+  // Clamp to at least 1 to avoid divide-by-zero when batch size is large.
+  blocks_per_act_slice = max(blocks_per_act_slice, 1);
+
   // Make sure we launch blocks per activation is no less than activations
   blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n));

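Both hunks place the clamp before the existing `min(...)` line. A small Python sketch of those two statements shows that the result stays at 1 or above for any positive `hw` and `n`. The function name `clamp_blocks_per_act_slice` is hypothetical, and `div_up` is assumed to be ceiling division.

```python
def div_up(m: int, n: int) -> int:
    """Ceiling division; assumed to match the kernel's div_up helper."""
    return (m + n - 1) // n


def clamp_blocks_per_act_slice(raw: int, hw: int, n: int) -> int:
    """Mirror the two statements in the diff: max(..., 1), then min(..., div_up(hw, n))."""
    # Added by this commit: a raw value of 0 becomes 1.
    clamped = max(raw, 1)
    # Pre-existing upper bound; div_up(hw, n) is at least 1 for positive hw and n,
    # so this step cannot undo the clamp.
    return min(clamped, div_up(hw, n))


if __name__ == "__main__":
    # raw == 0 models a truncated heuristic such as 512 // 1024.
    print(clamp_blocks_per_act_slice(raw=0, hw=8 * 8, n=1024))
    # A small batch keeps its heuristic value, capped by div_up(hw, n).
    print(clamp_blocks_per_act_slice(raw=512 // 8, hw=16 * 16, n=8))
```

Because the upper bound is itself never below 1, the later `div_up(params.hw, blocks_per_act_slice)` named in the commit message always receives a nonzero divisor.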

apex/contrib/test/group_norm/test_group_norm.py

Lines changed: 35 additions & 0 deletions
@@ -280,6 +280,41 @@ def test_16_groups(self):
                     )
                 self.verify_group_norm(GroupNorm, N=n, C=c, H=h, W=w, G=16, act="swish")

+    def test_large_batch_two_pass(self):
+        """Regression test for divide-by-zero when batch size is large.
+
+        When batch_size >= 256 and c >= 640, blocks_per_act_slice = 256 / n
+        truncates to 0, causing div_up(hw, 0). Test all three heuristic branches.
+        """
+        sizes = [
+            [256, 1280, 8, 8],
+            [512, 640, 16, 16],
+            [1024, 512, 8, 8],
+        ]
+        for sz in sizes:
+            with self.subTest(size=sz):
+                n, c, h, w = sz
+                required = _estimate_group_norm_test_bytes(
+                    N=n,
+                    C=c,
+                    H=h,
+                    W=w,
+                    xdtype=torch.float16,
+                    wdtype=torch.float32,
+                    ref_func=torch_group_norm_high_precision_fp64,
+                )
+                if not _has_sufficient_cuda_memory(required):
+                    free_bytes, total_bytes = torch.cuda.mem_get_info()
+                    raise unittest.SkipTest(
+                        f"Skipping large-batch GroupNorm case {sz}: estimated "
+                        f"{required / 1e9:.1f} GB requires more than available "
+                        f"free VRAM ({free_bytes / 1e9:.1f} GB free, "
+                        f"{total_bytes / 1e9:.1f} GB total)."
+                    )
+                self.verify_group_norm(
+                    cuda_group_norm_nhwc_two_pass, N=n, C=c, H=h, W=w, G=32, act="silu"
+                )
+
     def test_fp16_parameters(self):
         n, c, h, w = 8, 2560, 16, 16
         self.verify_group_norm(
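The docstring says the three sizes exercise all three heuristic branches. The diff only shows the `512 / params.n` branch, and the commit message gives `256 / params.n` for the C=640 case. The sketch below therefore rests on an assumption: the `c >= 1280` threshold and its budget of 128 are not shown in this commit and are hypothetical, as is the helper name `raw_blocks_per_act_slice`. Under that assumed table, each test size truncates to zero before the clamp.

```python
def raw_blocks_per_act_slice(n: int, c: int) -> int:
    """Unclamped heuristic under an ASSUMED three-branch budget table."""
    if c >= 1280:
        budget = 128  # assumption: threshold and budget not shown in the diff
    elif c >= 640:
        budget = 256  # from the commit message and the test docstring
    else:
        budget = 512  # visible in both diff hunks
    return budget // n  # truncating integer division


# The (N, C, H, W) sizes listed in test_large_batch_two_pass.
SIZES = [(256, 1280, 8, 8), (512, 640, 16, 16), (1024, 512, 8, 8)]

if __name__ == "__main__":
    for n, c, h, w in SIZES:
        raw = raw_blocks_per_act_slice(n, c)
        print(f"N={n} C={c}: raw={raw}, clamped={max(raw, 1)}")
```

If the real thresholds differ, the table in the sketch should be adjusted. The point is only that each size is chosen so that its branch's budget is smaller than N.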
