Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0c2c6ee

Browse files
authored
Added more fusion and vectorized kernel for transducer (#1125)
* Added support for fused ReLU and dropout into transducer joint * Reorganized code selection path in transducer joint fwd * Added support for fused ReLU+dropout into transducer joint * Vectorize transducer loss backward with fused softmax (#3) * Nanz/transducer loss (#4) * Vectorize transducer loss backward with fused softmax * Added a predicate to avoid potential IMA * Nanz/transducer loss (#5) * Vectorize transducer loss backward with fused softmax * Added a predicate to avoid potentional IMA * Added more predicates to avoid IMAs * Updated documentations for newly added features. * Fixed a error in transducer.py
1 parent ed71996 commit 0c2c6ee

8 files changed

Lines changed: 662 additions & 185 deletions

File tree

apex/contrib/csrc/transducer/transducer_joint.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
66
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
77

8-
torch::Tensor transducer_joint_cuda_forward(
8+
std::vector<torch::Tensor> transducer_joint_cuda_forward(
99
torch::Tensor f,
1010
torch::Tensor g,
1111
torch::Tensor fLen,
@@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward(
1414
int64_t packedBatch,
1515
int opt,
1616
bool packOutput,
17+
bool relu,
18+
bool dropout,
19+
float dropoutProb,
1720
int tileSize);
1821

1922

2023
std::vector<torch::Tensor> transducer_joint_cuda_backward(
21-
torch::Tensor grad,
24+
std::vector<torch::Tensor> in,
2225
torch::Tensor fLen,
2326
torch::Tensor gLen,
2427
torch::Tensor batchOffset,
2528
int maxFLen,
2629
int maxGLen,
27-
bool packOutput);
30+
bool packOutput,
31+
float scale);
2832

29-
torch::Tensor transducer_joint_forward(
33+
std::vector<torch::Tensor> transducer_joint_forward(
3034
torch::Tensor f,
3135
torch::Tensor g,
3236
torch::Tensor fLen,
@@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward(
3539
int64_t packedBatch,
3640
int opt,
3741
bool packOutput,
42+
bool relu,
43+
bool dropout,
44+
float dropoutProb,
3845
int tileSize) {
3946
CHECK_INPUT(f);
4047
CHECK_INPUT(g);
@@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward(
5158
packedBatch,
5259
opt,
5360
packOutput,
61+
relu,
62+
dropout,
63+
dropoutProb,
5464
tileSize);
5565
}
5666

5767
std::vector<torch::Tensor> transducer_joint_backward(
58-
torch::Tensor grad,
68+
std::vector<torch::Tensor> in,
5969
torch::Tensor fLen,
6070
torch::Tensor gLen,
6171
torch::Tensor batchOffset,
6272
int maxFLen,
6373
int maxGLen,
64-
bool packOutput) {
65-
CHECK_INPUT(grad);
74+
bool packOutput,
75+
float scale) {
76+
for (auto t : in){
77+
CHECK_INPUT(t);
78+
}
6679
CHECK_INPUT(fLen);
6780
CHECK_INPUT(gLen);
6881
if (packOutput)
6982
CHECK_INPUT(batchOffset);
7083
return transducer_joint_cuda_backward(
71-
grad,
84+
in,
7285
fLen,
7386
gLen,
7487
batchOffset,
7588
maxFLen,
7689
maxGLen,
77-
packOutput);
90+
packOutput,
91+
scale);
7892
}
7993

8094

0 commit comments

Comments
 (0)