-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathmlp.cpp
More file actions
112 lines (91 loc) · 4.65 KB
/
mlp.cpp
File metadata and controls
112 lines (91 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <stdio.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
// Forward declarations of the MLP routines called by the bindings below. Their definitions are not in this
// file (presumably they live in the accompanying CUDA source -- TODO confirm).

// Size of the buffer that mlp_forward allocates as `reserved_space`. mlp_forward uses the result as an
// element count of the input dtype. NOTE(review): whether the definition means elements or bytes cannot be
// told from here -- confirm.
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);
// Size of the scratch buffer for the backward pass. mlp_backward divides the result by sizeof(T) to get an
// element count, i.e. it treats the value as a byte count, as the name says.
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);
// Forward pass. As called from mlp_forward: X is the input, WPtr/BPtr are arrays of per-layer weight/bias
// pointers, output_features holds one entry per layer, Y receives the output, and lt_workspace is an extra
// scratch buffer. The int return value is ignored by the caller in this file.
template <typename T>
int mlp_fp(T* X, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T** BPtr, T* Y,
T* reserved_space, int use_bias, int activation, void* lt_workspace);
// Backward pass. As called from mlp_backward: X/Y are the forward input/output, dY is the incoming gradient,
// reserved_space is the buffer returned by the forward pass, and dX/dwPtr/dbPtr point at the gradient buffers
// for the input, weights and biases. The int return value is ignored by the caller in this file.
template <typename T>
int mlp_bp(T* X, T* Y, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T* dY,
T* reserved_space, T* work_space, T* dX, T** dwPtr, T** dbPtr, bool requires_grad, int use_bias,
int activation);
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
auto num_layers = inputs.size() - 1;
if (use_bias) {
// inputs contains (input, weights, biases)
num_layers /= 2;
}
auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1);
std::vector<int> output_features;
for (int i = 0; i < num_layers; i++) {
output_features.push_back(inputs[i + 1].size(0));
}
auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({static_cast<long>(reserved_size)}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].scalar_type(), "mlp_forward", [&] {
std::vector<scalar_t*> w_ptr;
std::vector<scalar_t*> b_ptr;
for (int i = 0; i < num_layers; i++) {
w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
if (use_bias) {
b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
}
}
[[maybe_unused]] auto result = mlp_fp<scalar_t>(inputs[0].data_ptr<scalar_t>(), input_features, batch_size,
w_ptr.data(), num_layers, output_features.data(), b_ptr.data(),
out.data_ptr<scalar_t>(), reserved_space.data_ptr<scalar_t>(),
use_bias, activation, (void*)(lt_workspace.data_ptr<scalar_t>()));
});
return {out, reserved_space};
}
std::vector<at::Tensor> mlp_backward(int use_bias, int activation, at::Tensor grad_o,
std::vector<at::Tensor> fprop_outputs, std::vector<at::Tensor> inputs) {
auto num_layers = inputs.size() - 1;
if (use_bias) {
// inputs contains (input, weights, biases)
num_layers /= 2;
}
auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1);
bool requires_grad = inputs[0].requires_grad();
std::vector<int> output_features;
for (int i = 0; i < num_layers; i++) {
output_features.push_back(inputs[i + 1].size(0));
}
// create outputs, length of inputs
std::vector<at::Tensor> outputs;
for (int i = 0; i < inputs.size(); i++) {
outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].scalar_type(), "mlp_backward", [&] {
std::vector<scalar_t*> w_ptr;
for (int i = 0; i < num_layers; i++) {
w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
}
std::vector<scalar_t*> outputs_ptr;
for (int i = 0; i < inputs.size(); i++) {
outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
}
auto work_size = get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
// auto work_space = at::empty({work_size*4}, at::kByte);
auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());
[[maybe_unused]] auto result = mlp_bp<scalar_t>(
inputs[0].data_ptr<scalar_t>(), fprop_outputs[0].data_ptr<scalar_t>(), input_features, batch_size, w_ptr.data(),
num_layers, output_features.data(), grad_o.contiguous().data_ptr<scalar_t>(),
fprop_outputs[1].data_ptr<scalar_t>(), work_space.data_ptr<scalar_t>(), outputs_ptr[0], outputs_ptr.data() + 1,
outputs_ptr.data() + 1 + num_layers, requires_grad, use_bias, activation);
});
return outputs;
}
// Python bindings: exposes the two entry points as `forward` and `backward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Both bindings drop the GIL for the duration of the C++ call.
  const py::call_guard<py::gil_scoped_release> release_gil{};

  m.def("forward", &mlp_forward, "MLP forward", release_gil);
  m.def("backward", &mlp_backward, "MLP backward", release_gil);
}