Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 4fd4f30

Browse files
authored
Update group norm license. (#1935)
1 parent 214c6b3 commit 4fd4f30

40 files changed

Lines changed: 277 additions & 759 deletions

apex/contrib/csrc/group_norm/group_norm_nhwc.cpp

Lines changed: 78 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,7 @@
1-
/***************************************************************************************************
2-
* Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
3-
*
4-
* Redistribution and use in source and binary forms, with or without modification, are not permit-
5-
* ted.
6-
*
7-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
8-
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
9-
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
10-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
11-
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
12-
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
13-
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
14-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15-
*
16-
**************************************************************************************************/
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*/
175
#include <traits.h>
186
#include <group_norm_nhwc.h>
197
#include <group_norm_nhwc_fwd_one_pass.h>
@@ -48,9 +36,9 @@ float inline unpack(const float& x) {
4836

4937
template <typename T>
5038
void check_results(const char *name,
51-
const T *out,
52-
const T *ref,
53-
size_t elts,
39+
const T *out,
40+
const T *ref,
41+
size_t elts,
5442
float tol) {
5543

5644
// The number of errors.
@@ -76,14 +64,14 @@ void check_results(const char *name,
7664
float abs_b = fabsf(b);
7765

7866
// Compute the error.
79-
float den = abs_a + abs_b;
67+
float den = abs_a + abs_b;
8068
// Is one of the quantities very small?
8169
bool is_small = abs_a <= tol || abs_b <= tol || den <= tol;
8270
// The error.
8371
float err = is_small ? fabsf(a-b) : fabsf(a-b) / den;
8472
// Is the result ok?
8573
bool ok = !isnan(a) && !isnan(b) && err <= tol;
86-
74+
8775
// Print the error.
8876
if( !ok && (failed < 10 || err > max_err) ) {
8977

@@ -146,19 +134,19 @@ template void check_results(const char *name, const float *out, const float *ref
146134

147135
////////////////////////////////////////////////////////////////////////////////////////////////////
148136

149-
static void group_norm_nhwc_bwd_(void *dx_h,
137+
static void group_norm_nhwc_bwd_(void *dx_h,
150138
float *dgamma_h,
151139
float *dbeta_h,
152-
const void *dy_h,
153-
const void *x_h,
140+
const void *dy_h,
141+
const void *x_h,
154142
const float *gamma_h,
155143
const float *beta_h,
156-
const float2 *sums_h,
144+
const float2 *sums_h,
157145
float epsilon,
158-
int n,
159-
int h,
160-
int w,
161-
int c,
146+
int n,
147+
int h,
148+
int w,
149+
int c,
162150
int groups,
163151
bool with_swish,
164152
bool use_fp32,
@@ -259,7 +247,7 @@ static void group_norm_nhwc_bwd_(void *dx_h,
259247
} // ii
260248
} // wi
261249
} // hi
262-
250+
263251
mean_1 *= rcp_hwc_per_group;
264252
mean_2 *= rcp_hwc_per_group;
265253

@@ -342,15 +330,15 @@ static void group_norm_nhwc_bwd_(void *dx_h,
342330

343331
////////////////////////////////////////////////////////////////////////////////////////////////////
344332

345-
static void group_norm_nhwc_fwd_(void *y_h,
346-
const void *x_h,
347-
const float *gamma_h,
348-
const float *beta_h,
333+
static void group_norm_nhwc_fwd_(void *y_h,
334+
const void *x_h,
335+
const float *gamma_h,
336+
const float *beta_h,
349337
float epsilon,
350-
int n,
351-
int h,
352-
int w,
353-
int c,
338+
int n,
339+
int h,
340+
int w,
341+
int c,
354342
int groups,
355343
bool with_swish,
356344
bool use_fp32,
@@ -602,7 +590,7 @@ int main(int argc, char **argv) {
602590
printf("mode.........................: bwd\n");
603591
} else if( mode == Mode::FWD_INFERENCE ) {
604592
printf("mode.........................: fwd inference\n");
605-
} else if( mode == Mode::FWD_TRAINING ) {
593+
} else if( mode == Mode::FWD_TRAINING ) {
606594
printf("mode.........................: fwd training\n");
607595
} else {
608596
assert(false);
@@ -672,7 +660,7 @@ int main(int argc, char **argv) {
672660
}
673661

674662
// Allocate the src/dst on the host for the gradients (bwd).
675-
void *dx_h = nullptr, *dy_h = nullptr;
663+
void *dx_h = nullptr, *dy_h = nullptr;
676664
if( mode == Mode::BWD ) {
677665
dx_h = malloc(x_sz);
678666
dy_h = malloc(x_sz);
@@ -798,20 +786,20 @@ int main(int argc, char **argv) {
798786

799787
// Compute the golden reference on the host.
800788
if (!skip_checks) {
801-
if( mode == Mode::BWD ) {
802-
group_norm_nhwc_bwd_(dx_ref_h,
789+
if( mode == Mode::BWD ) {
790+
group_norm_nhwc_bwd_(dx_ref_h,
803791
dgamma_ref_h,
804792
dbeta_ref_h,
805793
dy_h,
806-
x_h,
807-
gamma_h,
794+
x_h,
795+
gamma_h,
808796
beta_h,
809797
sums_h,
810-
epsilon,
811-
n,
812-
h,
813-
w,
814-
c,
798+
epsilon,
799+
n,
800+
h,
801+
w,
802+
c,
815803
groups,
816804
with_swish,
817805
use_fp32,
@@ -823,32 +811,32 @@ int main(int argc, char **argv) {
823811

824812
// Copy to the device.
825813
CHECK_CUDA(cudaMemcpyAsync(x_d, x_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault));
826-
CHECK_CUDA(cudaMemcpyAsync(gamma_d,
827-
gamma_h,
828-
gamma_sz,
829-
cudaMemcpyHostToDevice,
814+
CHECK_CUDA(cudaMemcpyAsync(gamma_d,
815+
gamma_h,
816+
gamma_sz,
817+
cudaMemcpyHostToDevice,
830818
cudaStreamDefault));
831-
CHECK_CUDA(cudaMemcpyAsync(beta_d,
832-
beta_h,
833-
gamma_sz,
834-
cudaMemcpyHostToDevice,
819+
CHECK_CUDA(cudaMemcpyAsync(beta_d,
820+
beta_h,
821+
gamma_sz,
822+
cudaMemcpyHostToDevice,
835823
cudaStreamDefault));
836824

837825
if( mode == Mode::BWD ) {
838-
CHECK_CUDA(cudaMemcpyAsync(dy_d,
839-
dy_h,
840-
x_sz,
841-
cudaMemcpyHostToDevice,
826+
CHECK_CUDA(cudaMemcpyAsync(dy_d,
827+
dy_h,
828+
x_sz,
829+
cudaMemcpyHostToDevice,
842830
cudaStreamDefault));
843831

844832
// // DEBUG.
845833
// printf("sums_h[0] = %8.3f, %8.3f\n", sums_h[0].x, sums_h[0].y);
846834
// // END OF DEBUG.
847835

848-
CHECK_CUDA(cudaMemcpyAsync(sums_d,
849-
sums_h,
850-
sums_sz,
851-
cudaMemcpyHostToDevice,
836+
CHECK_CUDA(cudaMemcpyAsync(sums_d,
837+
sums_h,
838+
sums_sz,
839+
cudaMemcpyHostToDevice,
852840
cudaStreamDefault));
853841
}
854842

@@ -878,7 +866,7 @@ int main(int argc, char **argv) {
878866
}();
879867

880868
// Initialize the parameters.
881-
if( mode == Mode::BWD ) {
869+
if( mode == Mode::BWD ) {
882870
params_bwd.dx = dx_d;
883871
params_bwd.dgamma = dgamma_d;
884872
params_bwd.dbeta = dbeta_d;
@@ -914,30 +902,30 @@ int main(int argc, char **argv) {
914902
// The number of barriers.
915903
size_t barriers_elts = 0;
916904
// The number of elements in the reduction buffer.
917-
size_t red_buffer_elts = 0;
905+
size_t red_buffer_elts = 0;
918906
// The number of elements in the reduction buffer that must be zeroed.
919-
size_t zeroed_red_buffer_elts = 0;
907+
size_t zeroed_red_buffer_elts = 0;
920908

921909
// Finalize the parameters.
922910
dim3 grid;
923911
if( mode == Mode::BWD && use_one_pass ) {
924-
group_norm_nhwc_bwd_one_pass_setup(params_bwd,
925-
barriers_elts,
926-
red_buffer_elts,
912+
group_norm_nhwc_bwd_one_pass_setup(params_bwd,
913+
barriers_elts,
914+
red_buffer_elts,
927915
zeroed_red_buffer_elts,
928-
grid,
916+
grid,
929917
props);
930918
} else if( mode == Mode::BWD ) {
931-
group_norm_nhwc_bwd_two_passes_setup(params_bwd,
919+
group_norm_nhwc_bwd_two_passes_setup(params_bwd,
932920
zeroed_red_buffer_elts);
933921
} else if( use_one_pass ) {
934-
group_norm_nhwc_fwd_one_pass_setup(params_fwd,
922+
group_norm_nhwc_fwd_one_pass_setup(params_fwd,
935923
barriers_elts,
936-
red_buffer_elts,
937-
grid,
924+
red_buffer_elts,
925+
grid,
938926
props);
939927
} else {
940-
group_norm_nhwc_fwd_two_passes_setup(params_fwd,
928+
group_norm_nhwc_fwd_two_passes_setup(params_fwd,
941929
zeroed_red_buffer_elts);
942930
}
943931

@@ -987,9 +975,9 @@ int main(int argc, char **argv) {
987975

988976
// Clear the zeroed buffer if needed.
989977
if( zeroed_red_buffer_sz > 0 ) {
990-
CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_,
991-
0,
992-
zeroed_red_buffer_sz,
978+
CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_,
979+
0,
980+
zeroed_red_buffer_sz,
993981
cudaStreamDefault));
994982
}
995983
if( use_one_pass && mode == Mode::BWD ) {
@@ -1020,15 +1008,15 @@ int main(int argc, char **argv) {
10201008
// Copy the results to the host.
10211009
if( mode == Mode::BWD ) {
10221010
CHECK_CUDA(cudaMemcpyAsync(dx_h, dx_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));
1023-
CHECK_CUDA(cudaMemcpyAsync(dgamma_h,
1024-
dgamma_d,
1025-
gamma_sz,
1026-
cudaMemcpyDeviceToHost,
1011+
CHECK_CUDA(cudaMemcpyAsync(dgamma_h,
1012+
dgamma_d,
1013+
gamma_sz,
1014+
cudaMemcpyDeviceToHost,
10271015
cudaStreamDefault));
1028-
CHECK_CUDA(cudaMemcpyAsync(dbeta_h,
1029-
dbeta_d,
1030-
gamma_sz,
1031-
cudaMemcpyDeviceToHost,
1016+
CHECK_CUDA(cudaMemcpyAsync(dbeta_h,
1017+
dbeta_d,
1018+
gamma_sz,
1019+
cudaMemcpyDeviceToHost,
10321020
cudaStreamDefault));
10331021
} else {
10341022
CHECK_CUDA(cudaMemcpyAsync(y_h, y_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));
@@ -1041,7 +1029,7 @@ int main(int argc, char **argv) {
10411029
if (!csv_output) {
10421030
if( mode == Mode::BWD && !skip_checks ) {
10431031
if (use_fp32) {
1044-
check_results<float>("dx", reinterpret_cast<float*>(dx_h),
1032+
check_results<float>("dx", reinterpret_cast<float*>(dx_h),
10451033
reinterpret_cast<float*>(dx_ref_h), x_elts, tol);
10461034
} else if (use_bf16) {
10471035
check_results<__nv_bfloat16>("dx", reinterpret_cast<__nv_bfloat16*>(dx_h),
@@ -1054,7 +1042,7 @@ int main(int argc, char **argv) {
10541042
check_results<float> ("dbeta", dbeta_h, dbeta_ref_h, gamma_elts, tol);
10551043
} else if( !skip_checks ) {
10561044
if (use_fp32) {
1057-
check_results<float>("y", reinterpret_cast<float*>(y_h),
1045+
check_results<float>("y", reinterpret_cast<float*>(y_h),
10581046
reinterpret_cast<float*>(y_ref_h), x_elts, tol);
10591047
} else if (use_bf16) {
10601048
check_results<__nv_bfloat16>("y", reinterpret_cast<__nv_bfloat16*>(y_h),
@@ -1107,4 +1095,3 @@ int main(int argc, char **argv) {
11071095
}
11081096

11091097
////////////////////////////////////////////////////////////////////////////////////////////////////
1110-

apex/contrib/csrc/group_norm/group_norm_nhwc.h

Lines changed: 8 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,7 @@
1-
/***************************************************************************************************
2-
* Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
3-
*
4-
* Redistribution and use in source and binary forms, with or without modification, are not permit-
5-
* ted.
6-
*
7-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
8-
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
9-
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
10-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
11-
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
12-
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
13-
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
14-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15-
*
16-
**************************************************************************************************/
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*/
175
#pragma once
186

197
#include <math.h>
@@ -138,7 +126,7 @@ struct Group_norm_nhwc_fwd_params {
138126
// The number of groups in each block.
139127
int groups_per_block;
140128
// The number of channels per group = c / groups.
141-
int channels_per_group;
129+
int channels_per_group;
142130
// The number of channels per block = groups_per_block * channels_per_group.
143131
int channels_per_block;
144132
// The inverse of hwc in floats (to compute mean/var).
@@ -149,7 +137,7 @@ struct Group_norm_nhwc_fwd_params {
149137

150138
////////////////////////////////////////////////////////////////////////////////////////////////////
151139

152-
void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&,
140+
void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&,
153141
size_t &red_buffer_elts);
154142

155143
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -205,7 +193,7 @@ struct Group_norm_nhwc_bwd_params {
205193
// The number of groups in each block.
206194
int groups_per_block;
207195
// The number of channels per group = c / groups.
208-
int channels_per_group;
196+
int channels_per_group;
209197
// The number of channels per block = groups_per_block * channels_per_group.
210198
int channels_per_block;
211199
// The inverse of hwc in floats (to compute mean/var).
@@ -216,7 +204,7 @@ struct Group_norm_nhwc_bwd_params {
216204

217205
////////////////////////////////////////////////////////////////////////////////////////////////////
218206

219-
void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params&,
207+
void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params&,
220208
size_t &red_buffer_elts);
221209

222210
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -228,4 +216,3 @@ void group_norm_nhwc_bwd_two_passes_sum (const Group_norm_nhwc_bwd_params&, cud
228216
void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params&, cudaStream_t);
229217

230218
////////////////////////////////////////////////////////////////////////////////////////////////////
231-

0 commit comments

Comments
 (0)