Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 700d682

Browse files
seryilmazSukru Eryilmaz
andauthored
fix default mode missing additive mask option (#924)
Co-authored-by: Sukru Eryilmaz <[email protected]>
1 parent 459de22 commit 700d682

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

apex/contrib/multihead_attn/self_multihead_attn_func.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
66
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
77
input_weights, output_weights,
88
input_biases, output_biases,
9-
mask, dropout_prob):
9+
mask, is_additive_mask, dropout_prob):
1010
use_biases_t = torch.tensor([input_biases is not None])
1111
heads_t = torch.tensor([heads])
1212
scale_t = torch.tensor([scale])
@@ -60,8 +60,11 @@ def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
6060
batches,seql_q,seql_k = matmul1_results.size()
6161
seqs = int(batches / heads)
6262
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
63-
mask = mask.to(torch.bool)
64-
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
63+
if is_additive_mask:
64+
matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
65+
else:
66+
mask = mask.to(torch.bool)
67+
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
6568
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
6669

6770
softmax_results = F.softmax(matmul1_results, dim=-1)

0 commit comments

Comments
 (0)