-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathmlp.py
More file actions
87 lines (70 loc) · 2.7 KB
/
mlp.py
File metadata and controls
87 lines (70 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from copy import copy
import math
import torch
from torch import nn
from apex._autocast_utils import _cast_if_autocast_enabled
import mlp_cuda
class MlpFunction(torch.autograd.Function):
    """Autograd bridge to the fused ``mlp_cuda`` forward/backward kernels.

    Called as ``MlpFunction.apply(bias, activation, *tensors)``. ``bias`` and
    ``activation`` are flags handed unchanged to the extension. ``tensors`` are
    the remaining arguments; the caller in this file passes the input, then all
    weights, then all biases.
    """

    @staticmethod
    def forward(ctx, bias, activation, *tensors):
        results = mlp_cuda.forward(bias, activation, tensors)
        # The input tensors go through save_for_backward. The extension's whole
        # result sequence is kept as a plain ctx attribute because backward passes
        # it back to the extension. Presumably it holds per-layer intermediates;
        # TODO confirm against mlp_cuda.
        ctx.save_for_backward(*tensors)
        ctx.outputs = results
        ctx.bias, ctx.activation = bias, activation
        # Only the first entry is exposed to autograd as this function's output.
        return results[0]

    @staticmethod
    def backward(ctx, grad_o):
        grads = mlp_cuda.backward(
            ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors
        )
        # Drop the cached forward results as soon as they have been consumed.
        # NOTE(review): this means a second backward through the same graph
        # (retain_graph=True) would raise AttributeError on ctx.outputs.
        del ctx.outputs
        # The bias/activation flags are not differentiable, so their grads are None.
        return (None, None) + tuple(grads)
def mlp_function(bias, activation, *args):
    """Apply ``MlpFunction`` after passing every argument through apex's autocast caster.

    ``_cast_if_autocast_enabled`` also receives the ``bias`` and ``activation``
    flags, not just the tensors. Presumably it leaves non-tensor values untouched;
    TODO confirm.
    """
    return MlpFunction.apply(*_cast_if_autocast_enabled(bias, activation, *args))
class MLP(torch.nn.Module):
    """Multi-layer perceptron whose forward pass runs through the fused ``mlp_cuda`` extension.

    Args:
        mlp_sizes (list of int): Layer widths, input width first. Example:
            [1024,1024,1024] creates 2 layers, each with a 1024x1024 weight.
        bias (bool): Whether every layer gets a learnable bias. Default True.
        activation (str): Activation applied by the extension; one of "none",
            "relu" or "sigmoid". Default "relu".

    Raises:
        TypeError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, mlp_sizes, bias=True, activation="relu"):
        super().__init__()
        self.num_layers = len(mlp_sizes) - 1
        self.mlp_sizes = copy(mlp_sizes)
        # Kept as an int flag because it is forwarded as-is to the extension.
        self.bias = 1 if bias else 0

        # Integer activation codes forwarded to the extension.
        if activation == "none":
            self.activation = 0
        elif activation == "relu":
            self.activation = 1
        elif activation == "sigmoid":
            self.activation = 2
        else:
            raise TypeError("activation must be relu, sigmoid or none.")

        # Register parameters interleaved (weight_i, then bias_i). This keeps the
        # named_parameters()/state_dict() order, and so optimizer state, unchanged.
        for i in range(self.num_layers):
            setattr(
                self,
                "weight_{}".format(i),
                torch.nn.Parameter(torch.empty(mlp_sizes[i + 1], mlp_sizes[i])),
            )
            if self.bias:
                setattr(
                    self,
                    "bias_{}".format(i),
                    torch.nn.Parameter(torch.empty(mlp_sizes[i + 1])),
                )

        self.reset_parameters()

    @property
    def weights(self):
        """Per-layer weight parameters, in layer order.

        These are resolved from the registered ``weight_{i}`` attributes on every
        access rather than cached in a list. That keeps them in sync when a
        parameter object is replaced, e.g. by reassigning ``weight_0`` or by
        ``load_state_dict(..., assign=True)``.
        """
        return [getattr(self, "weight_{}".format(i)) for i in range(self.num_layers)]

    @property
    def biases(self):
        """Per-layer bias parameters in layer order; empty when built with ``bias=False``."""
        if not self.bias:
            return []
        return [getattr(self, "bias_{}".format(i)) for i in range(self.num_layers)]

    def reset_parameters(self):
        """Re-initialise parameters in place from normal distributions.

        Weights use std sqrt(2 / (fan_out + fan_in)), which is Glorot/Xavier-normal
        with gain 1. Biases use std sqrt(1 / out_features).
        """
        for weight in self.weights:
            dimsum = weight.size(0) + weight.size(1)
            std = math.sqrt(2.0 / float(dimsum))
            nn.init.normal_(weight, 0.0, std)
        if self.bias:
            for bias in self.biases:
                std = math.sqrt(1.0 / float(bias.size(0)))
                nn.init.normal_(bias, 0.0, std)

    def forward(self, input):
        # Argument layout expected by mlp_function: input, all weights, then all biases.
        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)

    def extra_repr(self):
        s = f"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
        return s