Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 8cdcc82

Browse files
committed
Bug fixes
1 parent 67a0ffc commit 8cdcc82

2 files changed

Lines changed: 70 additions & 1 deletion

File tree

apex/contrib/bottleneck/bottleneck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def backward(ctx, grad_o):
354354
btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
355355
fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
356356
fat_halo[:,2:,:,:].copy_(btm_halo)
357-
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2,:,:])
357+
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
358358
relu_halo[:,2:,:,:].zero_()
359359
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
360360
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]

apex/contrib/csrc/bottleneck/bottleneck.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,73 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
21582158
return grad_out2;
21592159
}
21602160

2161+
// compute dgrad of 3x3 convolution without fusing with drelu and dscale
2162+
/// Stand-alone data-gradient (dgrad) pass: runs a backward-data convolution
/// with the filter in inputs[2] on grad_out2 and returns the result, with no
/// drelu/dscale fused into the call.
///
/// @param explicit_nhwc  selects the memory format of the returned tensor:
///                       Contiguous when true, ChannelsLast otherwise.
/// @param stride_1X1     not used by this function.
/// @param inputs         inputs[0] supplies the type of the result tensor;
///                       inputs[2] is the convolution filter, read as fp16.
/// @param outputs        not used by this function.
/// @param grad_out2      incoming gradient, read as fp16.
/// @return a newly allocated tensor of shape backward_state.outdim1 that
///         run_dconv writes the data gradient into.
///
/// NOTE(review): backward_state is file-level state that this function only
/// reads; presumably it is filled in by an earlier init call -- confirm.
at::Tensor bottleneck_backward_dgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // incoming gradient
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // result buffer and filter
  auto dgrad1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = dgrad1.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();

  // dgrad: dy1 is the tensor being produced, dy2 the gradient being consumed
  run_dconv(backward_state.outdimA1,
            backward_state.padA1,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2,
            backward_state.outdimA2,
            CUDNN_DATA_HALF,
            dy1,
            w,
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

  return dgrad1;
}
2196+
2197+
/// Halo variant of the stand-alone data-gradient (dgrad) pass: runs a
/// backward-data convolution with the filter in inputs[2] on grad_out2_halo,
/// using the halo dimensions stored in backward_state.
///
/// @param explicit_nhwc   selects the memory format of the returned tensor:
///                        Contiguous when true, ChannelsLast otherwise.
/// @param stride_1X1      not used by this function.
/// @param inputs          inputs[0] supplies the type of the result tensor;
///                        inputs[2] is the convolution filter, read as fp16.
/// @param outputs         not used by this function.
/// @param grad_out2_halo  incoming halo gradient, read as fp16.
/// @return a newly allocated tensor of shape backward_state.outdim1h that
///         run_dconv writes the data gradient into.
///
/// NOTE(review): backward_state is file-level state that this function only
/// reads; presumably it is filled in by an earlier init call -- confirm.
at::Tensor bottleneck_backward_dgrad1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // incoming halo gradient
  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();

  // result buffer and filter
  auto dgrad1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
  at::Half* dy1h = dgrad1_halo.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();

  // dgrad: same pad/stride/dilation/filter as the full pass, halo-sized tensors
  run_dconv(backward_state.outdimA1h,
            backward_state.padA1,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2,
            backward_state.outdimA2h,
            CUDNN_DATA_HALF,
            dy1h,
            w,
            dy2h,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

  return dgrad1_halo;
}
2227+
21612228
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
21622229

21632230
bool requires_grad = inputs[0].requires_grad();
@@ -2480,6 +2547,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
24802547
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
24812548
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
24822549
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
2550+
m.def("backward_dgrad1", &bottleneck_backward_dgrad1, "Bottleneck block backward");
2551+
m.def("backward_dgrad1_halo", &bottleneck_backward_dgrad1_halo, "Bottleneck block backward");
24832552
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
24842553
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
24852554
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");

0 commit comments

Comments
 (0)