Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 623315a

Browse files
authored
ConvFrozenScaleBiasReLU fusion (#1608)
* Cherry pick changes to ConvScaleBiasReLU fusion * Fix testbench * Add missing conv_cscale_cbias_relu_forward * Fix bug in setOperationGraph * Remove manual cuDNN heuristics knobs * Use torch.testing.assert_close for tensor comparison * Return at::Tensor instead of vector, add debug msg --------- Co-authored-by: Jaemin Choi <[email protected]>
1 parent b1c7600 commit 623315a

4 files changed

Lines changed: 570 additions & 14 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
1+
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU, ConvFrozenScaleBiasReLU
22

apex/contrib/conv_bias_relu/conv_bias_relu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,30 @@ def backward(ctx, grad_output):
7575
return grads[0], grads[1], grads[2], None, None
7676

7777

78+
class ConvFrozenScaleBiasReLU_(torch.autograd.Function):
79+
@staticmethod
80+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
81+
def forward(ctx, x, weight, scale, bias, padding, stride):
82+
output = fused_conv_bias_relu.forward_cscale_cbias_relu([x, weight, scale, bias], padding, stride)
83+
ctx.save_for_backward(x, weight, scale, output)
84+
ctx.padding = padding
85+
ctx.stride = stride
86+
87+
return output
88+
89+
@staticmethod
90+
@torch.cuda.amp.custom_bwd
91+
def backward(ctx, grad_output):
92+
bwd_args = [*ctx.saved_tensors, grad_output]
93+
padding = ctx.padding
94+
stride = ctx.stride
95+
grads = fused_conv_bias_relu.backward_cscale_cbias_relu(bwd_args, padding, stride)
96+
97+
return grads[0], grads[1], None, None, None, None
98+
99+
78100
ConvBiasReLU = ConvBiasReLU_.apply
79101
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
80102
ConvBias = ConvBias_.apply
103+
ConvFrozenScaleBiasReLU = ConvFrozenScaleBiasReLU_.apply
81104

0 commit comments

Comments
 (0)