Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a0d99fd

Browse files
committed
Fixed weight init for fused weight matrices in fused MHA by adding correct gain factor.
1 parent 1ff54b8 commit a0d99fd

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

apex/contrib/multihead_attn/encdec_multihead_attn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
from torch import nn
35
from torch.nn import Parameter
@@ -76,7 +78,11 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_ad
7678

7779
def reset_parameters(self):
7880
nn.init.xavier_uniform_(self.in_proj_weight_q)
79-
nn.init.xavier_uniform_(self.in_proj_weight_kv)
81+
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
82+
# initialized like a [hidden, hidden] matrix.
83+
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
84+
# therefore xavier_uniform gain should be set to sqrt(1.5).
85+
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
8086
nn.init.xavier_uniform_(self.out_proj_weight)
8187
if self.bias:
8288
nn.init.constant_(self.in_proj_bias_q, 0.)

apex/contrib/multihead_attn/self_multihead_attn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
from torch import nn
35
from torch.nn import Parameter
@@ -98,7 +100,11 @@ def reset_parameters(self):
98100
nn.init.xavier_uniform_(self.k_weight)
99101
nn.init.xavier_uniform_(self.v_weight)
100102
else:
101-
nn.init.xavier_uniform_(self.in_proj_weight)
103+
# in_proj_weight has shape [3 * hidden, hidden] but it should be
104+
# initialized like a [hidden, hidden] matrix.
105+
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
106+
# therefore xavier_uniform gain should be set to sqrt(2).
107+
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
102108
nn.init.xavier_uniform_(self.out_proj_weight)
103109
if self.bias:
104110
if self.separate_qkv_params:

0 commit comments

Comments
 (0)