@@ -53,20 +53,20 @@ def __init__(
5353 impl == "fast" and bias
5454 ), "additive mask not supported for fast mode without bias"
5555 if separate_qkv_params :
56- self .q_weight = Parameter (torch .Tensor (embed_dim , embed_dim ))
57- self .k_weight = Parameter (torch .Tensor (embed_dim , embed_dim ))
58- self .v_weight = Parameter (torch .Tensor (embed_dim , embed_dim ))
56+ self .q_weight = Parameter (torch .empty (embed_dim , embed_dim ))
57+ self .k_weight = Parameter (torch .empty (embed_dim , embed_dim ))
58+ self .v_weight = Parameter (torch .empty (embed_dim , embed_dim ))
5959 else :
60- self .in_proj_weight = Parameter (torch .Tensor (3 * embed_dim , embed_dim ))
61- self .out_proj_weight = Parameter (torch .Tensor (embed_dim , embed_dim ))
60+ self .in_proj_weight = Parameter (torch .empty (3 * embed_dim , embed_dim ))
61+ self .out_proj_weight = Parameter (torch .empty (embed_dim , embed_dim ))
6262 if self .bias :
6363 if separate_qkv_params :
64- self .q_bias = Parameter (torch .Tensor (embed_dim ))
65- self .k_bias = Parameter (torch .Tensor (embed_dim ))
66- self .v_bias = Parameter (torch .Tensor (embed_dim ))
64+ self .q_bias = Parameter (torch .empty (embed_dim ))
65+ self .k_bias = Parameter (torch .empty (embed_dim ))
66+ self .v_bias = Parameter (torch .empty (embed_dim ))
6767 else :
68- self .in_proj_bias = Parameter (torch .Tensor (3 * embed_dim ))
69- self .out_proj_bias = Parameter (torch .Tensor (embed_dim ))
68+ self .in_proj_bias = Parameter (torch .empty (3 * embed_dim ))
69+ self .out_proj_bias = Parameter (torch .empty (embed_dim ))
7070 else :
7171 if separate_qkv_params :
7272 self .register_parameter ("q_bias" , None )
@@ -82,8 +82,8 @@ def __init__(
8282 self .out_proj_bias = None
8383 if self .include_norm_add :
8484 if impl == "fast" :
85- self .lyr_nrm_gamma_weights = Parameter (torch .Tensor (embed_dim ))
86- self .lyr_nrm_beta_weights = Parameter (torch .Tensor (embed_dim ))
85+ self .lyr_nrm_gamma_weights = Parameter (torch .empty (embed_dim ))
86+ self .lyr_nrm_beta_weights = Parameter (torch .empty (embed_dim ))
8787 self .lyr_nrm = None
8888 else :
8989 self .register_parameter ("lyr_norm_gamma_weights" , None )
0 commit comments