@@ -6,7 +6,7 @@ class SelfAttnFunc(torch.autograd.Function):
66 def forward (ctx , use_time_mask , is_training , heads , scale , inputs ,
77 input_weights , output_weights ,
88 input_biases , output_biases ,
9- mask , dropout_prob ):
9+ mask , is_additive_mask , dropout_prob ):
1010 use_biases_t = torch .tensor ([input_biases is not None ])
1111 heads_t = torch .tensor ([heads ])
1212 scale_t = torch .tensor ([scale ])
@@ -60,8 +60,11 @@ def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
6060 batches ,seql_q ,seql_k = matmul1_results .size ()
6161 seqs = int (batches / heads )
6262 matmul1_results = matmul1_results .view (seqs , heads , seql_q , seql_k )
63- mask = mask .to (torch .bool )
64- matmul1_results = matmul1_results .masked_fill_ (mask .unsqueeze (1 ).unsqueeze (2 ), float ('-inf' ))
63+ if is_additive_mask :
64+ matmul1_results = matmul1_results + mask .unsqueeze (1 ).unsqueeze (2 )
65+ else :
66+ mask = mask .to (torch .bool )
67+ matmul1_results = matmul1_results .masked_fill_ (mask .unsqueeze (1 ).unsqueeze (2 ), float ('-inf' ))
6568 matmul1_results = matmul1_results .view (seqs * heads , seql_q , seql_k )
6669
6770 softmax_results = F .softmax (matmul1_results , dim = - 1 )
0 commit comments