1- /* **************************************************************************************************
2- * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
3- *
4- * Redistribution and use in source and binary forms, with or without modification, are not permit-
5- * ted.
6- *
7- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
8- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
9- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
10- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
11- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
12- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
13- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
14- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15- *
16- **************************************************************************************************/
1+ /*
2+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+ * SPDX-License-Identifier: BSD-3-Clause
4+ */
175#include < traits.h>
186#include < group_norm_nhwc.h>
197#include < group_norm_nhwc_fwd_one_pass.h>
@@ -48,9 +36,9 @@ float inline unpack(const float& x) {
4836
4937template <typename T>
5038void check_results (const char *name,
51- const T *out,
52- const T *ref,
53- size_t elts,
39+ const T *out,
40+ const T *ref,
41+ size_t elts,
5442 float tol) {
5543
5644 // The number of errors.
@@ -76,14 +64,14 @@ void check_results(const char *name,
7664 float abs_b = fabsf (b);
7765
7866 // Compute the error.
79- float den = abs_a + abs_b;
67+ float den = abs_a + abs_b;
8068 // Is one of the quantities very small?
8169 bool is_small = abs_a <= tol || abs_b <= tol || den <= tol;
8270 // The error.
8371 float err = is_small ? fabsf (a-b) : fabsf (a-b) / den;
8472 // Is the result ok?
8573 bool ok = !isnan (a) && !isnan (b) && err <= tol;
86-
74+
8775 // Print the error.
8876 if ( !ok && (failed < 10 || err > max_err) ) {
8977
@@ -146,19 +134,19 @@ template void check_results(const char *name, const float *out, const float *ref
146134
147135// //////////////////////////////////////////////////////////////////////////////////////////////////
148136
149- static void group_norm_nhwc_bwd_ (void *dx_h,
137+ static void group_norm_nhwc_bwd_ (void *dx_h,
150138 float *dgamma_h,
151139 float *dbeta_h,
152- const void *dy_h,
153- const void *x_h,
140+ const void *dy_h,
141+ const void *x_h,
154142 const float *gamma_h,
155143 const float *beta_h,
156- const float2 *sums_h,
144+ const float2 *sums_h,
157145 float epsilon,
158- int n,
159- int h,
160- int w,
161- int c,
146+ int n,
147+ int h,
148+ int w,
149+ int c,
162150 int groups,
163151 bool with_swish,
164152 bool use_fp32,
@@ -259,7 +247,7 @@ static void group_norm_nhwc_bwd_(void *dx_h,
259247 } // ii
260248 } // wi
261249 } // hi
262-
250+
263251 mean_1 *= rcp_hwc_per_group;
264252 mean_2 *= rcp_hwc_per_group;
265253
@@ -342,15 +330,15 @@ static void group_norm_nhwc_bwd_(void *dx_h,
342330
343331// //////////////////////////////////////////////////////////////////////////////////////////////////
344332
345- static void group_norm_nhwc_fwd_ (void *y_h,
346- const void *x_h,
347- const float *gamma_h,
348- const float *beta_h,
333+ static void group_norm_nhwc_fwd_ (void *y_h,
334+ const void *x_h,
335+ const float *gamma_h,
336+ const float *beta_h,
349337 float epsilon,
350- int n,
351- int h,
352- int w,
353- int c,
338+ int n,
339+ int h,
340+ int w,
341+ int c,
354342 int groups,
355343 bool with_swish,
356344 bool use_fp32,
@@ -602,7 +590,7 @@ int main(int argc, char **argv) {
602590 printf (" mode.........................: bwd\n " );
603591 } else if ( mode == Mode::FWD_INFERENCE ) {
604592 printf (" mode.........................: fwd inference\n " );
605- } else if ( mode == Mode::FWD_TRAINING ) {
593+ } else if ( mode == Mode::FWD_TRAINING ) {
606594 printf (" mode.........................: fwd training\n " );
607595 } else {
608596 assert (false );
@@ -672,7 +660,7 @@ int main(int argc, char **argv) {
672660 }
673661
674662 // Allocate the src/dst on the host for the gradients (bwd).
675- void *dx_h = nullptr , *dy_h = nullptr ;
663+ void *dx_h = nullptr , *dy_h = nullptr ;
676664 if ( mode == Mode::BWD ) {
677665 dx_h = malloc (x_sz);
678666 dy_h = malloc (x_sz);
@@ -798,20 +786,20 @@ int main(int argc, char **argv) {
798786
799787 // Compute the golden reference on the host.
800788 if (!skip_checks) {
801- if ( mode == Mode::BWD ) {
802- group_norm_nhwc_bwd_ (dx_ref_h,
789+ if ( mode == Mode::BWD ) {
790+ group_norm_nhwc_bwd_ (dx_ref_h,
803791 dgamma_ref_h,
804792 dbeta_ref_h,
805793 dy_h,
806- x_h,
807- gamma_h,
794+ x_h,
795+ gamma_h,
808796 beta_h,
809797 sums_h,
810- epsilon,
811- n,
812- h,
813- w,
814- c,
798+ epsilon,
799+ n,
800+ h,
801+ w,
802+ c,
815803 groups,
816804 with_swish,
817805 use_fp32,
@@ -823,32 +811,32 @@ int main(int argc, char **argv) {
823811
824812 // Copy to the device.
825813 CHECK_CUDA (cudaMemcpyAsync (x_d, x_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault));
826- CHECK_CUDA (cudaMemcpyAsync (gamma_d,
827- gamma_h,
828- gamma_sz,
829- cudaMemcpyHostToDevice,
814+ CHECK_CUDA (cudaMemcpyAsync (gamma_d,
815+ gamma_h,
816+ gamma_sz,
817+ cudaMemcpyHostToDevice,
830818 cudaStreamDefault));
831- CHECK_CUDA (cudaMemcpyAsync (beta_d,
832- beta_h,
833- gamma_sz,
834- cudaMemcpyHostToDevice,
819+ CHECK_CUDA (cudaMemcpyAsync (beta_d,
820+ beta_h,
821+ gamma_sz,
822+ cudaMemcpyHostToDevice,
835823 cudaStreamDefault));
836824
837825 if ( mode == Mode::BWD ) {
838- CHECK_CUDA (cudaMemcpyAsync (dy_d,
839- dy_h,
840- x_sz,
841- cudaMemcpyHostToDevice,
826+ CHECK_CUDA (cudaMemcpyAsync (dy_d,
827+ dy_h,
828+ x_sz,
829+ cudaMemcpyHostToDevice,
842830 cudaStreamDefault));
843831
844832 // // DEBUG.
845833 // printf("sums_h[0] = %8.3f, %8.3f\n", sums_h[0].x, sums_h[0].y);
846834 // // END OF DEBUG.
847835
848- CHECK_CUDA (cudaMemcpyAsync (sums_d,
849- sums_h,
850- sums_sz,
851- cudaMemcpyHostToDevice,
836+ CHECK_CUDA (cudaMemcpyAsync (sums_d,
837+ sums_h,
838+ sums_sz,
839+ cudaMemcpyHostToDevice,
852840 cudaStreamDefault));
853841 }
854842
@@ -878,7 +866,7 @@ int main(int argc, char **argv) {
878866 }();
879867
880868 // Initialize the parameters.
881- if ( mode == Mode::BWD ) {
869+ if ( mode == Mode::BWD ) {
882870 params_bwd.dx = dx_d;
883871 params_bwd.dgamma = dgamma_d;
884872 params_bwd.dbeta = dbeta_d;
@@ -914,30 +902,30 @@ int main(int argc, char **argv) {
914902 // The number of barriers.
915903 size_t barriers_elts = 0 ;
916904 // The number of elements in the reduction buffer.
917- size_t red_buffer_elts = 0 ;
905+ size_t red_buffer_elts = 0 ;
918906 // The number of elements in the reduction buffer that must be zeroed.
919- size_t zeroed_red_buffer_elts = 0 ;
907+ size_t zeroed_red_buffer_elts = 0 ;
920908
921909 // Finalize the parameters.
922910 dim3 grid;
923911 if ( mode == Mode::BWD && use_one_pass ) {
924- group_norm_nhwc_bwd_one_pass_setup (params_bwd,
925- barriers_elts,
926- red_buffer_elts,
912+ group_norm_nhwc_bwd_one_pass_setup (params_bwd,
913+ barriers_elts,
914+ red_buffer_elts,
927915 zeroed_red_buffer_elts,
928- grid,
916+ grid,
929917 props);
930918 } else if ( mode == Mode::BWD ) {
931- group_norm_nhwc_bwd_two_passes_setup (params_bwd,
919+ group_norm_nhwc_bwd_two_passes_setup (params_bwd,
932920 zeroed_red_buffer_elts);
933921 } else if ( use_one_pass ) {
934- group_norm_nhwc_fwd_one_pass_setup (params_fwd,
922+ group_norm_nhwc_fwd_one_pass_setup (params_fwd,
935923 barriers_elts,
936- red_buffer_elts,
937- grid,
924+ red_buffer_elts,
925+ grid,
938926 props);
939927 } else {
940- group_norm_nhwc_fwd_two_passes_setup (params_fwd,
928+ group_norm_nhwc_fwd_two_passes_setup (params_fwd,
941929 zeroed_red_buffer_elts);
942930 }
943931
@@ -987,9 +975,9 @@ int main(int argc, char **argv) {
987975
988976 // Clear the zeroed buffer if needed.
989977 if ( zeroed_red_buffer_sz > 0 ) {
990- CHECK_CUDA (cudaMemsetAsync (zeroed_red_buffer_d_,
991- 0 ,
992- zeroed_red_buffer_sz,
978+ CHECK_CUDA (cudaMemsetAsync (zeroed_red_buffer_d_,
979+ 0 ,
980+ zeroed_red_buffer_sz,
993981 cudaStreamDefault));
994982 }
995983 if ( use_one_pass && mode == Mode::BWD ) {
@@ -1020,15 +1008,15 @@ int main(int argc, char **argv) {
10201008 // Copy the results to the host.
10211009 if ( mode == Mode::BWD ) {
10221010 CHECK_CUDA (cudaMemcpyAsync (dx_h, dx_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));
1023- CHECK_CUDA (cudaMemcpyAsync (dgamma_h,
1024- dgamma_d,
1025- gamma_sz,
1026- cudaMemcpyDeviceToHost,
1011+ CHECK_CUDA (cudaMemcpyAsync (dgamma_h,
1012+ dgamma_d,
1013+ gamma_sz,
1014+ cudaMemcpyDeviceToHost,
10271015 cudaStreamDefault));
1028- CHECK_CUDA (cudaMemcpyAsync (dbeta_h,
1029- dbeta_d,
1030- gamma_sz,
1031- cudaMemcpyDeviceToHost,
1016+ CHECK_CUDA (cudaMemcpyAsync (dbeta_h,
1017+ dbeta_d,
1018+ gamma_sz,
1019+ cudaMemcpyDeviceToHost,
10321020 cudaStreamDefault));
10331021 } else {
10341022 CHECK_CUDA (cudaMemcpyAsync (y_h, y_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));
@@ -1041,7 +1029,7 @@ int main(int argc, char **argv) {
10411029 if (!csv_output) {
10421030 if ( mode == Mode::BWD && !skip_checks ) {
10431031 if (use_fp32) {
1044- check_results<float >(" dx" , reinterpret_cast <float *>(dx_h),
1032+ check_results<float >(" dx" , reinterpret_cast <float *>(dx_h),
10451033 reinterpret_cast <float *>(dx_ref_h), x_elts, tol);
10461034 } else if (use_bf16) {
10471035 check_results<__nv_bfloat16>(" dx" , reinterpret_cast <__nv_bfloat16*>(dx_h),
@@ -1054,7 +1042,7 @@ int main(int argc, char **argv) {
10541042 check_results<float > (" dbeta" , dbeta_h, dbeta_ref_h, gamma_elts, tol);
10551043 } else if ( !skip_checks ) {
10561044 if (use_fp32) {
1057- check_results<float >(" y" , reinterpret_cast <float *>(y_h),
1045+ check_results<float >(" y" , reinterpret_cast <float *>(y_h),
10581046 reinterpret_cast <float *>(y_ref_h), x_elts, tol);
10591047 } else if (use_bf16) {
10601048 check_results<__nv_bfloat16>(" y" , reinterpret_cast <__nv_bfloat16*>(y_h),
@@ -1107,4 +1095,3 @@ int main(int argc, char **argv) {
11071095}
11081096
11091097// //////////////////////////////////////////////////////////////////////////////////////////////////
1110-
0 commit comments