-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathmask_softmax_dropout_func.py
More file actions
97 lines (85 loc) · 3.22 KB
/
mask_softmax_dropout_func.py
File metadata and controls
97 lines (85 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import fast_multihead_attn
class MaskSoftmaxDropout(torch.autograd.Function):
    """Autograd wrapper around the fused mask/softmax/dropout kernels of ``fast_multihead_attn``.

    ``forward`` calls either the additive-mask or the plain masked
    ``*_mask_softmax_dropout_forward`` kernel of the compiled
    ``fast_multihead_attn`` extension. That kernel returns the dropout output,
    the dropout mask and the softmax results. The last two are saved so that
    ``backward`` can call the matching ``*_mask_softmax_dropout_backward``
    kernel.

    Invoke through ``MaskSoftmaxDropout.apply`` with the positional arguments
    ``(is_training, heads, inputs, pad_mask, mask_additive, dropout_prob)``.
    Only ``inputs`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
        """Run the fused forward kernel and stash what ``backward`` needs.

        Args:
            ctx: autograd context. Non-tensor settings are stored as attributes;
                tensors are saved with ``save_for_backward``.
            is_training: forwarded unchanged to the kernel (presumably toggles
                dropout; confirm against the extension).
            heads: number of heads, forwarded unchanged to both kernels.
            inputs: input tensor (presumably attention scores). It is the only
                argument that gets a gradient.
            pad_mask: mask tensor, or ``None`` to disable masking.
            mask_additive: if truthy, the additive-mask kernels are used;
                otherwise the non-additive masked kernels.
            dropout_prob: dropout probability forwarded to both kernels.

        Returns:
            The kernel's dropout output, detached.
        """
        from apex import deprecated_warning

        deprecated_warning(
            "`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. "
            "We encourage you to migrate to PyTorch native MultiheadAttention. "
            "The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html"
        )

        use_mask = pad_mask is not None
        # The extension expects a tensor in the mask slot even when masking is off.
        mask_arg = pad_mask if use_mask else torch.tensor([])

        if mask_additive:
            forward_kernel = fast_multihead_attn.additive_mask_softmax_dropout_forward
        else:
            forward_kernel = fast_multihead_attn.mask_softmax_dropout_forward

        dropout_results, dropout_mask, softmax_results = forward_kernel(
            use_mask,
            is_training,
            heads,
            inputs,
            mask_arg,
            dropout_prob,
        )

        # Plain Python settings live on ctx. Boxing them into tensors allocated
        # a tiny CPU tensor per setting on every call, only to unbox them again
        # in backward.
        ctx.use_mask = use_mask
        ctx.heads = heads
        ctx.mask_additive = mask_additive
        ctx.dropout_prob = dropout_prob
        ctx.save_for_backward(softmax_results, dropout_mask, mask_arg)

        return dropout_results.detach()

    @staticmethod
    def backward(ctx, output_grads):
        """Compute the gradient w.r.t. ``inputs`` with the kernel matching ``forward``.

        Args:
            ctx: context populated by ``forward``.
            output_grads: gradient of the loss w.r.t. the forward output.

        Returns:
            A 6-tuple aligned with ``forward``'s arguments after ``ctx``. Only the
            third entry (``inputs``) is non-``None``.
        """
        softmax_results, dropout_mask, pad_mask = ctx.saved_tensors

        if ctx.mask_additive:
            # The additive-mask backward kernel is not given the mask tensor.
            input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward(
                ctx.use_mask,
                ctx.heads,
                output_grads,
                softmax_results,
                dropout_mask,
                ctx.dropout_prob,
            )
        else:
            input_grads = fast_multihead_attn.mask_softmax_dropout_backward(
                ctx.use_mask,
                ctx.heads,
                output_grads,
                softmax_results,
                dropout_mask,
                pad_mask,
                ctx.dropout_prob,
            )

        return None, None, input_grads, None, None, None
# Functional alias for the autograd Function. Call it positionally as
# fast_mask_softmax_dropout_func(is_training, heads, inputs, pad_mask, mask_additive, dropout_prob).
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply