Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d86d1b0

Browse files
authored
Initial check-in of the transducer extensions (#1069)
* Initial check-in of the transducer extension. * Added more comments to help explain the code * Corrected minor typos * 1. Renamed variable in tests to match the extension 2. Disabled ninja build option
1 parent e2083df commit d86d1b0

10 files changed

Lines changed: 2093 additions & 0 deletions

File tree

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,84 @@
1+
#include <torch/extension.h>
2+
#include <ATen/Functions.h>
3+
4+
// Asserts via TORCH_CHECK that `x.is_cuda()` holds; the message names the argument.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
// Asserts via TORCH_CHECK that `x.is_contiguous()` holds; the message names the argument.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Runs both checks on `x`. The do { ... } while (0) wrapper makes the expansion a
// single statement, so an unbraced `if (cond) CHECK_INPUT(x);` guards BOTH checks
// instead of only the first one.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
7+
8+
// Forward pass of the transducer joint, defined in another translation unit
// (presumably the accompanying CUDA source -- confirm against the build).
// The wrapper in this file validates the tensor arguments and then forwards
// every parameter here unchanged, in the same order.
torch::Tensor transducer_joint_cuda_forward(
    torch::Tensor f, torch::Tensor g,
    torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset,
    int64_t packedBatch,
    int opt,
    bool packOutput,
    int tileSize);
18+
19+
20+
// Backward pass of the transducer joint, defined in another translation unit
// (presumably the accompanying CUDA source -- confirm against the build).
// Returns a vector of tensors; the wrapper in this file validates the tensor
// arguments and then forwards every parameter here unchanged.
std::vector<torch::Tensor> transducer_joint_cuda_backward(
    torch::Tensor grad,
    torch::Tensor fLen, torch::Tensor gLen,
    torch::Tensor batchOffset,
    int maxFLen, int maxGLen,
    bool packOutput);
28+
29+
// Validating front end for the transducer joint forward pass.
//
// Checks that `f`, `g`, `fLen` and `gLen` are contiguous CUDA tensors, and
// applies the same checks to `batchOffset` only when `packOutput` is true.
// All arguments are then handed to transducer_joint_cuda_forward unchanged,
// and its result is returned.
//
// `packedBatch`, `opt` and `tileSize` are not inspected here; their meaning
// is defined by the CUDA implementation.
torch::Tensor transducer_joint_forward(
    torch::Tensor f,
    torch::Tensor g,
    torch::Tensor fLen,
    torch::Tensor gLen,
    torch::Tensor batchOffset,
    int64_t packedBatch,
    int opt,
    bool packOutput,
    int tileSize) {
    CHECK_INPUT(f);
    CHECK_INPUT(g);
    CHECK_INPUT(fLen);
    CHECK_INPUT(gLen);
    // Braces keep the whole validation conditional even if CHECK_INPUT
    // expands to more than one statement; without them only the first
    // expanded statement would be guarded by `packOutput`.
    if (packOutput) {
        CHECK_INPUT(batchOffset);
    }
    return transducer_joint_cuda_forward(
        f,
        g,
        fLen,
        gLen,
        batchOffset,
        packedBatch,
        opt,
        packOutput,
        tileSize);
}
56+
57+
// Validating front end for the transducer joint backward pass.
//
// Checks that `grad`, `fLen` and `gLen` are contiguous CUDA tensors, and
// applies the same checks to `batchOffset` only when `packOutput` is true.
// All arguments are then handed to transducer_joint_cuda_backward unchanged,
// and its vector of tensors is returned.
//
// `maxFLen` and `maxGLen` are not inspected here; their meaning is defined
// by the CUDA implementation.
std::vector<torch::Tensor> transducer_joint_backward(
    torch::Tensor grad,
    torch::Tensor fLen,
    torch::Tensor gLen,
    torch::Tensor batchOffset,
    int maxFLen,
    int maxGLen,
    bool packOutput) {
    CHECK_INPUT(grad);
    CHECK_INPUT(fLen);
    CHECK_INPUT(gLen);
    // Braces keep the whole validation conditional even if CHECK_INPUT
    // expands to more than one statement; without them only the first
    // expanded statement would be guarded by `packOutput`.
    if (packOutput) {
        CHECK_INPUT(batchOffset);
    }
    return transducer_joint_cuda_backward(
        grad,
        fLen,
        gLen,
        batchOffset,
        maxFLen,
        maxGLen,
        packOutput);
}
79+
80+
81+
// Python bindings: registers the two validating wrappers on the extension
// module under the names "forward" and "backward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &transducer_joint_forward,
        "transducer joint forward (CUDA)");
    m.def(
        "backward",
        &transducer_joint_backward,
        "transducer joint backward (CUDA)");
}

0 commit comments

Comments
 (0)