Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 10417ac

Browse files
tlogn and crcrpar authored
fix groupnorm int32 index overflow (#1845)
* fix groupnorm int32 index overflow
* add groupnorm test case

---------

Co-authored-by: Masaki Kozuki <[email protected]>
1 parent c1fa083 commit 10417ac

4 files changed

Lines changed: 7 additions & 6 deletions

File tree

apex/contrib/csrc/group_norm/group_norm_nhwc.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ struct Group_norm_nhwc_fwd_params {
111 111   // The number of instances in the batch.
112 112   int n;
113 113   // The height and width of each activation map. The number of channels.
114     - int h, w, c, hw, hwc;
    114 + int64_t h, w, c, hw, hwc;
115 115   // The number of groups.
116 116   int groups;
117 117   // Do we apply the Swish activation function?
@@ -178,7 +178,7 @@ struct Group_norm_nhwc_bwd_params {
178 178   // The number of instances in the batch.
179 179   int n;
180 180   // The height and width of each activation map. The number of channels.
181     - int h, w, c, hw, hwc;
    181 + int64_t h, w, c, hw, hwc;
182 182   // The number of groups.
183 183   int groups;
184 184   // Do we apply the Swish activation function?

apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params
79 79   // The first activation loaded by that block.
80 80   int hw_begin = blockIdx.y * params.acts_per_block;
81 81   // The last activation loaded by that block.
82    - int hw_end = min(hw_begin + params.acts_per_block, params.hw);
   82 + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);
83 83
84 84   // The gradients for gamma/beta.
85 85   float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f);
@@ -364,7 +364,7 @@ __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params para
364 364   // The first activation loaded by that block.
365 365   int hw_begin = blockIdx.y * params.acts_per_block;
366 366   // The last activation loaded by that block.
367     - int hw_end = min(hw_begin + params.acts_per_block, params.hw);
    367 + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);
368 368
369 369   // Iterate over the activations to compute the sums.
370 370   for( int hwi = hw_begin; hwi < hw_end; ++hwi ) {

apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params
43 43   // The first activation loaded by that block.
44 44   int hw_begin = blockIdx.y * params.acts_per_block;
45 45   // The last activation loaded by that block.
46    - int hw_end = min(hw_begin + params.acts_per_block, params.hw);
   46 + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);
47 47
48 48   // The sums.
49 49   float sum = 0.f, sum_sq = 0.f;
@@ -273,7 +273,7 @@ __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params para
273 273   // The first activation loaded by that block.
274 274   int hw_begin = blockIdx.y * params.acts_per_block;
275 275   // The last activation loaded by that block.
276     - int hw_end = min(hw_begin + params.acts_per_block, params.hw);
    276 + int hw_end = min((int64_t) hw_begin + params.acts_per_block, params.hw);
277 277
278 278   // Iterate over the activations to compute the sums.
279 279   for( int hwi = hw_begin; hwi < hw_end; ++hwi ) {

apex/contrib/test/group_norm/test_group_norm.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,7 @@ def test_16_groups(self):
208 208   [8, 1920, 32, 32],
209 209   [8, 1920, 16, 16],
210 210   [8, 2560, 8, 8],
    211 + [1, 128, 16128, 1200],
211 212   ]
212 213   for sz in sizes:
213 214   n, c, h, w = sz

0 commit comments

Comments
 (0)