@@ -3554,19 +3554,12 @@ std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_
35543554 return outputs;
35553555}
35563556
3557- at::Tensor bottleneck_backward_grad_out2 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
3558-
3559- bool requires_grad = inputs[0 ].requires_grad ();
3560-
3561- std::cout << std::fixed;
3562- auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
3557+ void bottleneck_backward_wgrad3 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
35633558
35643559 // dconv3+drelu2+dscale2
35653560 at::Half* conv_in = inputs[13 ].data_ptr <at::Half>();
35663561 at::Half* dy3 = inputs[10 ].data_ptr <at::Half>();
35673562
3568- DEBUG_MSG (" [DEBUG] new dconv3 : " << inputs[10 ].to (at::kFloat ).sum ().item <float >());
3569-
35703563 // wgrad
35713564 auto wgrad3 = outputs[3 ];
35723565 at::Half* dw3 = wgrad3.data_ptr <at::Half>();
@@ -3583,6 +3576,21 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
35833576 CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
35843577 DEBUG_MSG (" [DEBUG] new wgrad3 : " << wgrad3.to (at::kFloat ).sum ().item <float >());
35853578
3579+ }
3580+
3581+ at::Tensor bottleneck_backward_grad_out2 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
3582+
3583+ bool requires_grad = inputs[0 ].requires_grad ();
3584+
3585+ std::cout << std::fixed;
3586+ auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
3587+
3588+ // dconv3+drelu2+dscale2
3589+ at::Half* conv_in = inputs[13 ].data_ptr <at::Half>();
3590+ at::Half* dy3 = inputs[10 ].data_ptr <at::Half>();
3591+
3592+ DEBUG_MSG (" [DEBUG] new dconv3 : " << inputs[10 ].to (at::kFloat ).sum ().item <float >());
3593+
35863594 // dgrad
35873595 auto grad_out2 = at::empty (backward_state.outdim2 , inputs[0 ].type (), output_format);
35883596 at::Half* dy2 = grad_out2.data_ptr <at::Half>();
@@ -3769,7 +3777,7 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
37693777 return grad_out1_halo;
37703778}
37713779
3772- at::Tensor bottleneck_backward_wgrad2_pad (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
3780+ void bottleneck_backward_wgrad2_pad (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
37733781
37743782 std::cout << std::fixed;
37753783 auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
@@ -3798,11 +3806,9 @@ at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, st
37983806 dy2,
37993807 CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
38003808 DEBUG_MSG (" [DEBUG] new wgrad2 : " << wgrad2.to (at::kFloat ).sum ().item <float >());
3801-
3802- return wgrad2;
38033809}
38043810
3805- at::Tensor bottleneck_backward_wgrad2 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
3811+ void bottleneck_backward_wgrad2 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
38063812
38073813 bool requires_grad = inputs[0 ].requires_grad ();
38083814
@@ -3832,8 +3838,6 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
38323838 dy2,
38333839 CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
38343840 DEBUG_MSG (" [DEBUG] new wgrad2 : " << wgrad2.to (at::kFloat ).sum ().item <float >());
3835-
3836- return wgrad2;
38373841}
38383842
38393843// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
@@ -3876,7 +3880,30 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s
38763880 return wgrad2_halo;
38773881}
38783882
3879- void bottleneck_backward_rest (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
3883+ void bottleneck_backward_wgrad1 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
3884+
3885+ at::Half* x = inputs[0 ].data_ptr <at::Half>();
3886+ at::Half* dy1 = grad_out1.data_ptr <at::Half>();
3887+
3888+ // dconv1+add
3889+ // wgrad
3890+ auto wgrad1 = outputs[1 ];
3891+ at::Half* dw1 = wgrad1.data_ptr <at::Half>();
3892+ run_dconv (backward_state.dimA ,
3893+ backward_state.padA ,
3894+ backward_state.convstride1X1 ,
3895+ backward_state.dilationA ,
3896+ backward_state.filterdimA1 ,
3897+ backward_state.outdimA1 ,
3898+ CUDNN_DATA_HALF,
3899+ x,
3900+ dw1,
3901+ dy1,
3902+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
3903+
3904+ }
3905+
3906+ void bottleneck_backward_rest (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
38803907
38813908 bool requires_grad = inputs[0 ].requires_grad ();
38823909
@@ -3974,22 +4001,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
39744001 dx_conv4 = inputs[11 ].data_ptr <at::Half>();
39754002 }
39764003
3977- // dconv1+add
3978- // wgrad
3979- auto wgrad1 = outputs[1 ];
3980- at::Half* dw1 = wgrad1.data_ptr <at::Half>();
3981- run_dconv (backward_state.dimA ,
3982- backward_state.padA ,
3983- backward_state.convstride1X1 ,
3984- backward_state.dilationA ,
3985- backward_state.filterdimA1 ,
3986- backward_state.outdimA1 ,
3987- CUDNN_DATA_HALF,
3988- x,
3989- dw1,
3990- dy1,
3991- CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
3992-
39934004 // dgrad
39944005 w = inputs[1 ].data_ptr <at::Half>();
39954006 auto grad_x = outputs[0 ];
@@ -4056,5 +4067,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
40564067 m.def (" backward_wgrad2_pad" , &bottleneck_backward_wgrad2_pad, " Bottleneck block backward" );
40574068 m.def (" backward_wgrad2" , &bottleneck_backward_wgrad2, " Bottleneck block backward" );
40584069 m.def (" backward_wgrad2_halo" , &bottleneck_backward_wgrad2_halo, " Bottleneck block backward" );
4070+ m.def (" backward_wgrad3" , &bottleneck_backward_wgrad3, " Bottleneck block backward" );
4071+ m.def (" backward_wgrad1" , &bottleneck_backward_wgrad1, " Bottleneck block backward" );
40594072 m.def (" backward_rest" , &bottleneck_backward_rest, " Bottleneck block backward" );
40604073}
0 commit comments