Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5698eee

Browse files
committed
Bit faster
1 parent 140282d commit 5698eee

2 files changed

Lines changed: 60 additions & 37 deletions

File tree

apex/contrib/bottleneck/bottleneck.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -448,14 +448,11 @@ def backward(ctx, grad_o):
448448
t_list.append(ctx.saved_tensors[10])
449449

450450
grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
451+
wgrad3_stream = torch.cuda.Stream()
452+
wgrad3_stream.wait_stream(torch.cuda.current_stream())
451453
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
452454
wgrad2_stream = torch.cuda.Stream()
453455
wgrad2_stream.wait_stream(torch.cuda.current_stream())
454-
with torch.cuda.stream(wgrad2_stream):
455-
if ctx.spatial_group_size > 1:
456-
wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
457-
else:
458-
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
459456
# do halo exchange of grad_out2 here
460457
# compute halo cells for grad_out1
461458
if ctx.spatial_group_size > 1:
@@ -576,8 +573,21 @@ def backward(ctx, grad_o):
576573
if ctx.spatial_group_rank > 0:
577574
torch.cuda.current_stream().wait_stream(ctx.stream1)
578575

579-
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
576+
wgrad1_stream = torch.cuda.Stream()
577+
wgrad1_stream.wait_stream(torch.cuda.current_stream())
578+
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
579+
with torch.cuda.stream(wgrad3_stream):
580+
fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
581+
with torch.cuda.stream(wgrad2_stream):
582+
if ctx.spatial_group_size > 1:
583+
fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
584+
else:
585+
fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
586+
with torch.cuda.stream(wgrad1_stream):
587+
fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
588+
torch.cuda.current_stream().wait_stream(wgrad3_stream)
580589
torch.cuda.current_stream().wait_stream(wgrad2_stream)
590+
torch.cuda.current_stream().wait_stream(wgrad1_stream)
581591

582592
return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
583593

apex/contrib/csrc/bottleneck/bottleneck.cpp

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3554,19 +3554,12 @@ std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_
35543554
return outputs;
35553555
}
35563556

3557-
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
3558-
3559-
bool requires_grad = inputs[0].requires_grad();
3560-
3561-
std::cout << std::fixed;
3562-
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
3557+
void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
35633558

35643559
// dconv3+drelu2+dscale2
35653560
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
35663561
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
35673562

3568-
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
3569-
35703563
// wgrad
35713564
auto wgrad3 = outputs[3];
35723565
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
@@ -3583,6 +3576,21 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
35833576
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
35843577
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
35853578

3579+
}
3580+
3581+
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
3582+
3583+
bool requires_grad = inputs[0].requires_grad();
3584+
3585+
std::cout << std::fixed;
3586+
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
3587+
3588+
// dconv3+drelu2+dscale2
3589+
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
3590+
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
3591+
3592+
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
3593+
35863594
// dgrad
35873595
auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
35883596
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
@@ -3769,7 +3777,7 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
37693777
return grad_out1_halo;
37703778
}
37713779

3772-
at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
3780+
void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
37733781

37743782
std::cout << std::fixed;
37753783
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
@@ -3798,11 +3806,9 @@ at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, st
37983806
dy2,
37993807
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
38003808
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
3801-
3802-
return wgrad2;
38033809
}
38043810

3805-
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
3811+
void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
38063812

38073813
bool requires_grad = inputs[0].requires_grad();
38083814

@@ -3832,8 +3838,6 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
38323838
dy2,
38333839
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
38343840
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
3835-
3836-
return wgrad2;
38373841
}
38383842

38393843
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
@@ -3876,7 +3880,30 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s
38763880
return wgrad2_halo;
38773881
}
38783882

3879-
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
3883+
void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
3884+
3885+
at::Half* x = inputs[0].data_ptr<at::Half>();
3886+
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
3887+
3888+
// dconv1+add
3889+
// wgrad
3890+
auto wgrad1 = outputs[1];
3891+
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
3892+
run_dconv(backward_state.dimA,
3893+
backward_state.padA,
3894+
backward_state.convstride1X1,
3895+
backward_state.dilationA,
3896+
backward_state.filterdimA1,
3897+
backward_state.outdimA1,
3898+
CUDNN_DATA_HALF,
3899+
x,
3900+
dw1,
3901+
dy1,
3902+
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
3903+
3904+
}
3905+
3906+
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
38803907

38813908
bool requires_grad = inputs[0].requires_grad();
38823909

@@ -3974,22 +4001,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
39744001
dx_conv4 = inputs[11].data_ptr<at::Half>();
39754002
}
39764003

3977-
// dconv1+add
3978-
// wgrad
3979-
auto wgrad1 = outputs[1];
3980-
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
3981-
run_dconv(backward_state.dimA,
3982-
backward_state.padA,
3983-
backward_state.convstride1X1,
3984-
backward_state.dilationA,
3985-
backward_state.filterdimA1,
3986-
backward_state.outdimA1,
3987-
CUDNN_DATA_HALF,
3988-
x,
3989-
dw1,
3990-
dy1,
3991-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
3992-
39934004
// dgrad
39944005
w = inputs[1].data_ptr<at::Half>();
39954006
auto grad_x = outputs[0];
@@ -4056,5 +4067,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
40564067
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
40574068
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
40584069
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
4070+
m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
4071+
m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
40594072
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
40604073
}

0 commit comments

Comments
 (0)