Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ba027dd

Browse files
authored
replace torch.Tensor with torch.empty (#1578)
* replace torch.Tensor with torch.empty * nit * nit
1 parent 0145d69 commit ba027dd

9 files changed

Lines changed: 43 additions & 43 deletions

File tree

apex/RNN/RNNBackend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,17 +254,17 @@ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_stat
254254
self.gate_size = gate_multiplier * self.hidden_size
255255
self.n_hidden_states = n_hidden_states
256256

257-
self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
258-
self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))
257+
self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size))
258+
self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size))
259259

260260
#Check if there's recurrent projection
261261
if(self.output_size != self.hidden_size):
262-
self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))
262+
self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size))
263263

264264
self.b_ih = self.b_hh = None
265265
if self.bias:
266-
self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
267-
self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))
266+
self.b_ih = nn.Parameter(torch.empty(self.gate_size))
267+
self.b_hh = nn.Parameter(torch.empty(self.gate_size))
268268

269269
#hidden states for forward
270270
self.hidden = [ None for states in range(self.n_hidden_states)]

apex/RNN/cells.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def __init__(self, input_size, hidden_size, bias = False, output_size = None):
1818
gate_multiplier = 4
1919
super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
2020

21-
self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))
22-
self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))
21+
self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size))
22+
self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size))
2323

2424
self.reset_parameters()
2525

apex/amp/compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ def variable_is_tensor():
66
return isinstance(v, torch.Tensor)
77

88
def tensor_is_variable():
9-
x = torch.Tensor()
9+
x = torch.empty()
1010
return type(x) == torch.autograd.Variable
1111

1212
# False for post-0.4
1313
def tensor_is_float_tensor():
14-
x = torch.Tensor()
14+
x = torch.empty()
1515
return type(x) == torch.FloatTensor
1616

1717
# Akin to `torch.is_tensor`, but returns True for Variable

apex/contrib/layer_norm/layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class FastLayerNorm(torch.nn.Module):
4141
def __init__(self, hidden_size, eps=1e-5):
4242
super().__init__()
4343
self.epsilon = eps
44-
self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
45-
self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
44+
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
45+
self.bias = torch.nn.Parameter(torch.empty(hidden_size))
4646
self.reset_parameters()
4747

4848
def reset_parameters(self):

apex/contrib/multihead_attn/encdec_multihead_attn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a
3636
self.impl = impl
3737
self.scaling = self.head_dim ** -0.5
3838

39-
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
40-
self.in_proj_weight_kv = Parameter(torch.Tensor(2 * embed_dim, embed_dim))
41-
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
39+
self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim))
40+
self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim))
41+
self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
4242
if self.bias:
4343
assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
44-
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
45-
self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
46-
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
44+
self.in_proj_bias_q = Parameter(torch.empty(embed_dim))
45+
self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim))
46+
self.out_proj_bias = Parameter(torch.empty(embed_dim))
4747
else:
4848
self.register_parameter("in_proj_bias_q", None)
4949
self.register_parameter("in_proj_bias_kv", None)
@@ -52,8 +52,8 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a
5252
self.out_proj_bias = None
5353
if self.include_norm_add:
5454
if impl == "fast":
55-
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
56-
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
55+
self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
56+
self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
5757
self.lyr_nrm = None
5858
else:
5959
self.register_parameter("lyr_norm_gamma_weights", None)

apex/contrib/multihead_attn/self_multihead_attn.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,20 @@ def __init__(
5353
impl == "fast" and bias
5454
), "additive mask not supported for fast mode without bias"
5555
if separate_qkv_params:
56-
self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
57-
self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
58-
self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
56+
self.q_weight = Parameter(torch.empty(embed_dim, embed_dim))
57+
self.k_weight = Parameter(torch.empty(embed_dim, embed_dim))
58+
self.v_weight = Parameter(torch.empty(embed_dim, embed_dim))
5959
else:
60-
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
61-
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
60+
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
61+
self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
6262
if self.bias:
6363
if separate_qkv_params:
64-
self.q_bias = Parameter(torch.Tensor(embed_dim))
65-
self.k_bias = Parameter(torch.Tensor(embed_dim))
66-
self.v_bias = Parameter(torch.Tensor(embed_dim))
64+
self.q_bias = Parameter(torch.empty(embed_dim))
65+
self.k_bias = Parameter(torch.empty(embed_dim))
66+
self.v_bias = Parameter(torch.empty(embed_dim))
6767
else:
68-
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
69-
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
68+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
69+
self.out_proj_bias = Parameter(torch.empty(embed_dim))
7070
else:
7171
if separate_qkv_params:
7272
self.register_parameter("q_bias", None)
@@ -82,8 +82,8 @@ def __init__(
8282
self.out_proj_bias = None
8383
if self.include_norm_add:
8484
if impl == "fast":
85-
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
86-
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
85+
self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
86+
self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
8787
self.lyr_nrm = None
8888
else:
8989
self.register_parameter("lyr_norm_gamma_weights", None)

apex/contrib/sparsity/sparse_masklib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def compute_valid_1d_patterns(m,n):
2929
if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
3030
patterns = torch.zeros(m)
3131
patterns[:n] = 1
32-
valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
32+
valid_patterns = torch.empty(list(set(permutations(patterns.tolist()))))
3333
if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
3434
return valid_patterns
3535

@@ -109,10 +109,10 @@ def compute_valid_2d_patterns(m,n):
109109
patterns[:n] = 1
110110
patterns = list(set(permutations(patterns.tolist())))
111111
patterns = patterns + patterns
112-
patterns = torch.Tensor(list(set(permutations(patterns,m))))
112+
patterns = torch.empty(list(set(permutations(patterns,m))))
113113

114114
valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)
115-
valid_patterns = torch.Tensor(valid.shape[0],m,m)
115+
valid_patterns = torch.empty(valid.shape[0],m,m)
116116
valid_patterns[:] = patterns[valid[:]]
117117

118118
if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns

apex/fused_dense/fused_dense.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def __init__(self, in_features, out_features, bias=True):
6666
super(FusedDense, self).__init__()
6767
self.in_features = in_features
6868
self.out_features = out_features
69-
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
69+
self.weight = nn.Parameter(torch.empty(out_features, in_features))
7070
if bias:
71-
self.bias = nn.Parameter(torch.Tensor(out_features))
71+
self.bias = nn.Parameter(torch.empty(out_features))
7272
else:
7373
#assert False, "no-bias option not added yet"
7474
self.register_parameter('bias', None)
@@ -86,10 +86,10 @@ def __init__(self, in_features, intermediate_features, out_features, bias=True):
8686
self.in_features = in_features
8787
self.intermediate_features = intermediate_features
8888
self.out_features = out_features
89-
self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features))
90-
self.bias1 = nn.Parameter(torch.Tensor(intermediate_features))
91-
self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features))
92-
self.bias2 = nn.Parameter(torch.Tensor(out_features))
89+
self.weight1 = nn.Parameter(torch.empty(intermediate_features, in_features))
90+
self.bias1 = nn.Parameter(torch.empty(intermediate_features))
91+
self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features))
92+
self.bias2 = nn.Parameter(torch.empty(out_features))
9393

9494
def forward(self, input):
9595
return _fused_dense_gelu_dense(input, self.weight1, self.bias1, self.weight2, self.bias2)

apex/normalization/fused_layer_norm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
273273
self.eps = eps
274274
self.elementwise_affine = elementwise_affine
275275
if self.elementwise_affine:
276-
self.weight = Parameter(torch.Tensor(*normalized_shape))
277-
self.bias = Parameter(torch.Tensor(*normalized_shape))
276+
self.weight = Parameter(torch.empty(*normalized_shape))
277+
self.bias = Parameter(torch.empty(*normalized_shape))
278278
else:
279279
self.register_parameter("weight", None)
280280
self.register_parameter("bias", None)
@@ -369,7 +369,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
369369
self.eps = eps
370370
self.elementwise_affine = elementwise_affine
371371
if self.elementwise_affine:
372-
self.weight = Parameter(torch.Tensor(*normalized_shape))
372+
self.weight = Parameter(torch.empty(*normalized_shape))
373373
else:
374374
self.register_parameter("weight", None)
375375
self.reset_parameters()

0 commit comments

Comments
 (0)