@@ -2158,6 +2158,73 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
21582158 return grad_out2;
21592159}
21602160
2161+ // compute dgrad of 3x3 convolution without fusing with drelu and dscale
2162+ at::Tensor bottleneck_backward_dgrad1 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
2163+
2164+ bool requires_grad = inputs[0 ].requires_grad ();
2165+
2166+ std::cout << std::fixed;
2167+ auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
2168+
2169+ // dgrad
2170+ at::Half* dy2 = grad_out2.data_ptr <at::Half>();
2171+
2172+ // dgrad
2173+ auto dgrad1 = at::empty (backward_state.outdim1 , inputs[0 ].type (), output_format);
2174+ at::Half* dy1 = dgrad1.data_ptr <at::Half>();
2175+ at::Half* w = inputs[2 ].data_ptr <at::Half>();
2176+ at::Half* z = inputs[4 ].data_ptr <at::Half>();
2177+
2178+ at::Half* relu1 = inputs[12 ].data_ptr <at::Half>();
2179+ // printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
2180+
2181+ // dgrad
2182+ run_dconv (backward_state.outdimA1 ,
2183+ backward_state.padA1 ,
2184+ backward_state.convstrideA ,
2185+ backward_state.dilationA ,
2186+ backward_state.filterdimA2 ,
2187+ backward_state.outdimA2 ,
2188+ CUDNN_DATA_HALF,
2189+ dy1,
2190+ w,
2191+ dy2,
2192+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
2193+
2194+ return dgrad1;
2195+ }
2196+
2197+ at::Tensor bottleneck_backward_dgrad1_halo (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {
2198+
2199+ bool requires_grad = inputs[0 ].requires_grad ();
2200+
2201+ std::cout << std::fixed;
2202+ auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
2203+
2204+ // dgrad
2205+ at::Half* dy2h = grad_out2_halo.data_ptr <at::Half>();
2206+
2207+ // dgrad
2208+ auto dgrad1_halo = at::empty (backward_state.outdim1h , inputs[0 ].type (), output_format);
2209+ at::Half* dy1h = dgrad1_halo.data_ptr <at::Half>();
2210+ at::Half* w = inputs[2 ].data_ptr <at::Half>();
2211+
2212+ // dgrad
2213+ run_dconv (backward_state.outdimA1h ,
2214+ backward_state.padA1 ,
2215+ backward_state.convstrideA ,
2216+ backward_state.dilationA ,
2217+ backward_state.filterdimA2 ,
2218+ backward_state.outdimA2h ,
2219+ CUDNN_DATA_HALF,
2220+ dy1h,
2221+ w,
2222+ dy2h,
2223+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
2224+
2225+ return dgrad1_halo;
2226+ }
2227+
21612228at::Tensor bottleneck_backward_grad_out1 (bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
21622229
21632230 bool requires_grad = inputs[0 ].requires_grad ();
@@ -2480,6 +2547,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
24802547 m.def (" backward_grad_out2" , &bottleneck_backward_grad_out2, " Bottleneck block backward" );
24812548 m.def (" backward_grad_out1" , &bottleneck_backward_grad_out1, " Bottleneck block backward" );
24822549 m.def (" backward_grad_out1_halo" , &bottleneck_backward_grad_out1_halo, " Bottleneck block backward" );
2550+ m.def (" backward_dgrad1" , &bottleneck_backward_dgrad1, " Bottleneck block backward" );
2551+ m.def (" backward_dgrad1_halo" , &bottleneck_backward_dgrad1_halo, " Bottleneck block backward" );
24832552 m.def (" backward_wgrad2" , &bottleneck_backward_wgrad2, " Bottleneck block backward" );
24842553 m.def (" backward_wgrad2_halo" , &bottleneck_backward_wgrad2_halo, " Bottleneck block backward" );
24852554 m.def (" backward_rest" , &bottleneck_backward_rest, " Bottleneck block backward" );
0 commit comments