Commit eb1de0e

Author: Thomas Unterthiner (committed)
ENH Add 'axis' argument to sparsefuncs.mean_variance_axis
Parent: 42ca0a0 | Commit: eb1de0e

8 files changed: +93, -33 lines changed


doc/developers/utilities.rst

Lines changed: 2 additions & 2 deletions
@@ -140,8 +140,8 @@ Efficient Routines for Sparse Matrices
 The ``sklearn.utils.sparsefuncs`` cython module hosts compiled extensions to
 efficiently process ``scipy.sparse`` data.

-- :func:`sparsefuncs.mean_variance_axis0`: compute the means and
-  variances along axis 0 of a CSR matrix.
+- :func:`sparsefuncs.mean_variance_axis`: compute the means and
+  variances along a specified axis of a CSR matrix.
   Used for normalizing the tolerance stopping criterion in
   :class:`sklearn.cluster.k_means_.KMeans`.
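
Note: as a quick illustration of the renamed helper (a minimal usage sketch, not part of the commit), the new axis argument selects per-feature (axis=0) or per-sample (axis=1) statistics:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.utils.sparsefuncs import mean_variance_axis

    X_dense = np.array([[1., 0., 2.],
                        [0., 3., 0.]])
    X_csr = sp.csr_matrix(X_dense)

    # Per-feature statistics (axis=0) agree with the dense NumPy results.
    means, variances = mean_variance_axis(X_csr, axis=0)
    assert np.allclose(means, X_dense.mean(axis=0))      # [0.5, 1.5, 1.0]
    assert np.allclose(variances, X_dense.var(axis=0))   # [0.25, 2.25, 1.0]

    # Per-sample statistics (axis=1) are the new capability added here.
    means, variances = mean_variance_axis(X_csr, axis=1)
    assert np.allclose(means, X_dense.mean(axis=1))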

sklearn/cluster/k_means_.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from ..metrics.pairwise import euclidean_distances
 from ..utils.extmath import row_norms, squared_norm
 from ..utils.sparsefuncs_fast import assign_rows_csr
-from ..utils.sparsefuncs import mean_variance_axis0
+from ..utils.sparsefuncs import mean_variance_axis
 from ..utils.fixes import astype
 from ..utils import check_array
 from ..utils import check_random_state
@@ -141,7 +141,7 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
 def _tolerance(X, tol):
     """Return a tolerance which is independent of the dataset"""
     if sp.issparse(X):
-        variances = mean_variance_axis0(X)[1]
+        variances = mean_variance_axis(X, axis=0)[1]
     else:
         variances = np.var(X, axis=0)
     return np.mean(variances) * tol
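
Note: _tolerance scales tol by the mean per-feature variance, so sparse and dense inputs share the same stopping criterion. A small sketch of that equivalence (_tolerance is a private helper, so this mirrors its logic rather than importing it):

    import numpy as np
    import scipy.sparse as sp
    from sklearn.utils.sparsefuncs import mean_variance_axis

    rng = np.random.RandomState(0)
    X = rng.rand(20, 5)
    X[X < 0.5] = 0.0          # make the data genuinely sparse
    tol = 1e-4

    dense_tol = np.mean(np.var(X, axis=0)) * tol
    sparse_tol = np.mean(mean_variance_axis(sp.csr_matrix(X), axis=0)[1]) * tol
    assert np.isclose(dense_tol, sparse_tol)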

sklearn/decomposition/truncated_svd.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import check_array, as_float_array, check_random_state
 from ..utils.extmath import randomized_svd, safe_sparse_dot, svd_flip
-from ..utils.sparsefuncs import mean_variance_axis0
+from ..utils.sparsefuncs import mean_variance_axis

 __all__ = ["TruncatedSVD"]

@@ -175,7 +175,7 @@ def fit_transform(self, X, y=None):
         X_transformed = np.dot(U, np.diag(Sigma))
         self.explained_variance_ = exp_var = np.var(X_transformed, axis=0)
         if sp.issparse(X):
-            _, full_var = mean_variance_axis0(X)
+            _, full_var = mean_variance_axis(X, axis=0)
             full_var = full_var.sum()
         else:
             full_var = np.var(X, axis=0).sum()
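
Note: in fit_transform, full_var is the total variance of the input, which the estimator uses to normalize the per-component explained variance (the ratio itself is computed outside this hunk, so treating it as exp_var / full_var is an assumption). A sketch of that denominator for sparse input:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.utils.sparsefuncs import mean_variance_axis

    rng = np.random.RandomState(0)
    X = sp.csr_matrix(rng.rand(50, 10) * (rng.rand(50, 10) > 0.7))

    # Per-feature variances of the sparse data, summed into a total variance.
    _, full_var = mean_variance_axis(X, axis=0)
    total_variance = full_var.sum()
    assert np.isclose(total_variance, X.toarray().var(axis=0).sum())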

sklearn/linear_model/base.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
 from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
 from ..utils import as_float_array, check_array
 from ..utils.extmath import safe_sparse_dot
-from ..utils.sparsefuncs import mean_variance_axis0, inplace_column_scale
+from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale


 ###
@@ -48,14 +48,14 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
     if fit_intercept:
         # we might require not to change the csr matrix sometimes
         # store a copy if normalize is True.
-        # Change dtype to float64 since mean_variance_axis0 accepts
+        # Change dtype to float64 since mean_variance_axis accepts
        # it that way.
         if sp.isspmatrix(X) and X.getformat() == 'csr':
             X = sp.csr_matrix(X, copy=normalize, dtype=np.float64)
         else:
             X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)

-        X_mean, X_var = mean_variance_axis0(X)
+        X_mean, X_var = mean_variance_axis(X, axis=0)
         if normalize:
             # transform variance to std in-place
             # XXX: currently scaled to variance=n_samples to match center_data
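
Note: the normalize branch of sparse_center_data is truncated by this hunk, so the continuation below is only a hypothetical sketch of how a variance from mean_variance_axis is typically turned into a column scale with inplace_column_scale (both imported above); it is not the literal library code, and it omits the extra n_samples scaling mentioned in the XXX comment:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.utils.sparsefuncs import mean_variance_axis, inplace_column_scale

    X = sp.csr_matrix(np.array([[1., 0., 4.],
                                [2., 3., 0.],
                                [0., 6., 8.]]), dtype=np.float64)

    X_mean, X_var = mean_variance_axis(X, axis=0)
    X_std = np.sqrt(X_var)                 # hypothetical scale; constant columns get std 0
    X_std[X_std == 0] = 1.0                # avoid dividing such columns by zero
    inplace_column_scale(X, 1.0 / X_std)   # scales the CSR data in place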

sklearn/preprocessing/data.py

Lines changed: 5 additions & 6 deletions
@@ -16,10 +16,9 @@
 from ..utils import warn_if_not_float
 from ..utils.extmath import row_norms
 from ..utils.fixes import combinations_with_replacement as combinations_w_r
-from ..utils.sparsefuncs_fast import inplace_csr_row_normalize_l1
-from ..utils.sparsefuncs_fast import inplace_csr_row_normalize_l2
-from ..utils.sparsefuncs import inplace_column_scale
-from ..utils.sparsefuncs import mean_variance_axis0
+from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
+                                      inplace_csr_row_normalize_l2)
+from ..utils.sparsefuncs import (inplace_column_scale, mean_variance_axis)

 zip = six.moves.zip
 map = six.moves.map
@@ -124,7 +123,7 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
             copy = False
         if copy:
             X = X.copy()
-        _, var = mean_variance_axis0(X)
+        _, var = mean_variance_axis(X, axis=0)
         var[var == 0.0] = 1.0
         inplace_column_scale(X, 1 / np.sqrt(var))
     else:
@@ -319,7 +318,7 @@ def fit(self, X, y=None):
             self.mean_ = None

         if self.with_std:
-            var = mean_variance_axis0(X)[1]
+            var = mean_variance_axis(X, axis=0)[1]
             self.std_ = np.sqrt(var)
             self.std_[var == 0.0] = 1.0
         else:
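
Note: a quick check of the sparse path in scale (not part of the commit): with centering disabled, every non-constant column of the CSR result ends up with unit variance, which mean_variance_axis can verify directly:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.preprocessing import scale
    from sklearn.utils.sparsefuncs import mean_variance_axis

    X = sp.csr_matrix(np.array([[1., 0., 4.],
                                [2., 3., 0.],
                                [0., 6., 8.]]))

    X_scaled = scale(X, with_mean=False)     # centering would densify sparse input
    _, var = mean_variance_axis(X_scaled, axis=0)
    assert np.allclose(var, 1.0)             # all three columns are non-constant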

sklearn/preprocessing/tests/test_data.py

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@
 from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_warns

-from sklearn.utils.sparsefuncs import mean_variance_axis0
+from sklearn.utils.sparsefuncs import mean_variance_axis
 from sklearn.preprocessing.data import _transform_selected
 from sklearn.preprocessing.data import Binarizer
 from sklearn.preprocessing.data import KernelCenterer
@@ -283,7 +283,7 @@ def test_scaler_without_centering():
         X_scaled.mean(axis=0), [0., -0.01, 2.24, -0.35, -0.78], 2)
     assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.])

-    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0(X_csr_scaled)
+    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis(X_csr_scaled, 0)
     assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0))
     assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0))

@@ -349,8 +349,8 @@ def test_scaler_int():
         [0., 1.109, 1.856, 21., 1.559], 2)
     assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.])

-    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0(
-        X_csr_scaled.astype(np.float))
+    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis(
+        X_csr_scaled.astype(np.float), 0)
     assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0))
     assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0))

@@ -432,7 +432,7 @@ def test_scale_function_without_centering():
     # Check that X has not been copied
     assert_true(X_scaled is not X)

-    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis0(X_csr_scaled)
+    X_csr_scaled_mean, X_csr_scaled_std = mean_variance_axis(X_csr_scaled, 0)
     assert_array_almost_equal(X_csr_scaled_mean, X_scaled.mean(axis=0))
     assert_array_almost_equal(X_csr_scaled_std, X_scaled.std(axis=0))

sklearn/utils/sparsefuncs.py

Lines changed: 22 additions & 6 deletions
@@ -6,8 +6,8 @@
 import numpy as np

 from .fixes import sparse_min_max
-from .sparsefuncs_fast import (csr_mean_variance_axis0,
-                               csc_mean_variance_axis0)
+from .sparsefuncs_fast import csr_mean_variance_axis0 as _csr_mean_var_axis0
+from .sparsefuncs_fast import csc_mean_variance_axis0 as _csc_mean_var_axis0


 def _raise_typeerror(X):
@@ -53,14 +53,17 @@ def inplace_csr_row_scale(X, scale):
     X.data *= np.repeat(scale, np.diff(X.indptr))


-def mean_variance_axis0(X):
+def mean_variance_axis(X, axis):
     """Compute mean and variance along axis 0 on a CSR or CSC matrix

     Parameters
     ----------
     X: CSR or CSC sparse matrix, shape (n_samples, n_features)
         Input data.

+    axis: int (either 0 or 1)
+        Axis along which the axis should be computed.
+
     Returns
     -------

@@ -71,10 +74,20 @@ def mean_variance_axis0(X):
         Feature-wise variances

     """
+    if axis != 0 and axis != 1:
+        raise ValueError(
+            "Unknown axis value: %d. Use 0 for rows, or 1 for columns" % axis)
+
     if isinstance(X, sp.csr_matrix):
-        return csr_mean_variance_axis0(X)
+        if axis == 0:
+            return _csr_mean_var_axis0(X)
+        else:
+            return _csc_mean_var_axis0(X.T)
     elif isinstance(X, sp.csc_matrix):
-        return csc_mean_variance_axis0(X)
+        if axis == 0:
+            return _csc_mean_var_axis0(X)
+        else:
+            return _csr_mean_var_axis0(X.T)
     else:
         _raise_typeerror(X)

@@ -258,13 +271,16 @@ def inplace_swap_column(X, m, n):


 def min_max_axis(X, axis):
-    """Compute minimum and maximum along axis 0 on a CSR or CSC matrix
+    """Compute minimum and maximum along an axis on a CSR or CSC matrix

     Parameters
     ----------
     X : CSR or CSC sparse matrix, shape (n_samples, n_features)
         Input data.

+    axis: int (either 0 or 1)
+        Axis along which the axis should be computed.
+
     Returns
     -------
sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from numpy.testing import assert_array_almost_equal, assert_array_equal
66

77
from sklearn.datasets import make_classification
8-
from sklearn.utils.sparsefuncs import (mean_variance_axis0,
8+
from sklearn.utils.sparsefuncs import (mean_variance_axis,
99
inplace_column_scale,
1010
inplace_row_scale,
1111
inplace_swap_row, inplace_swap_column,
@@ -26,27 +26,72 @@ def test_mean_variance_axis0():
2626
X[1, 0] = 0
2727
X_csr = sp.csr_matrix(X_lil)
2828

29-
X_means, X_vars = mean_variance_axis0(X_csr)
29+
X_means, X_vars = mean_variance_axis(X_csr, axis=0)
3030
assert_array_almost_equal(X_means, np.mean(X, axis=0))
3131
assert_array_almost_equal(X_vars, np.var(X, axis=0))
3232

3333
X_csc = sp.csc_matrix(X_lil)
34-
X_means, X_vars = mean_variance_axis0(X_csc)
34+
X_means, X_vars = mean_variance_axis(X_csc, axis=0)
3535

3636
assert_array_almost_equal(X_means, np.mean(X, axis=0))
3737
assert_array_almost_equal(X_vars, np.var(X, axis=0))
38-
assert_raises(TypeError, mean_variance_axis0, X_lil)
38+
assert_raises(TypeError, mean_variance_axis, X_lil, axis=0)
3939

4040
X = X.astype(np.float32)
4141
X_csr = X_csr.astype(np.float32)
4242
X_csc = X_csr.astype(np.float32)
43-
X_means, X_vars = mean_variance_axis0(X_csr)
43+
X_means, X_vars = mean_variance_axis(X_csr, axis=0)
4444
assert_array_almost_equal(X_means, np.mean(X, axis=0))
4545
assert_array_almost_equal(X_vars, np.var(X, axis=0))
46-
X_means, X_vars = mean_variance_axis0(X_csc)
46+
X_means, X_vars = mean_variance_axis(X_csc, axis=0)
4747
assert_array_almost_equal(X_means, np.mean(X, axis=0))
4848
assert_array_almost_equal(X_vars, np.var(X, axis=0))
49-
assert_raises(TypeError, mean_variance_axis0, X_lil)
49+
assert_raises(TypeError, mean_variance_axis, X_lil, axis=0)
50+
51+
52+
def test_mean_variance_illegal_axis():
53+
X, _ = make_classification(5, 4, random_state=0)
54+
# Sparsify the array a little bit
55+
X[0, 0] = 0
56+
X[2, 1] = 0
57+
X[4, 3] = 0
58+
X_csr = sp.csr_matrix(X)
59+
assert_raises(ValueError, mean_variance_axis, X_csr, axis=-3)
60+
assert_raises(ValueError, mean_variance_axis, X_csr, axis=2)
61+
assert_raises(ValueError, mean_variance_axis, X_csr, axis=-1)
62+
63+
def test_mean_variance_axis1():
64+
X, _ = make_classification(5, 4, random_state=0)
65+
# Sparsify the array a little bit
66+
X[0, 0] = 0
67+
X[2, 1] = 0
68+
X[4, 3] = 0
69+
X_lil = sp.lil_matrix(X)
70+
X_lil[1, 0] = 0
71+
X[1, 0] = 0
72+
X_csr = sp.csr_matrix(X_lil)
73+
74+
X_means, X_vars = mean_variance_axis(X_csr, axis=1)
75+
assert_array_almost_equal(X_means, np.mean(X, axis=1))
76+
assert_array_almost_equal(X_vars, np.var(X, axis=1))
77+
78+
X_csc = sp.csc_matrix(X_lil)
79+
X_means, X_vars = mean_variance_axis(X_csc, axis=1)
80+
81+
assert_array_almost_equal(X_means, np.mean(X, axis=1))
82+
assert_array_almost_equal(X_vars, np.var(X, axis=1))
83+
assert_raises(TypeError, mean_variance_axis, X_lil, axis=1)
84+
85+
X = X.astype(np.float32)
86+
X_csr = X_csr.astype(np.float32)
87+
X_csc = X_csr.astype(np.float32)
88+
X_means, X_vars = mean_variance_axis(X_csr, axis=1)
89+
assert_array_almost_equal(X_means, np.mean(X, axis=1))
90+
assert_array_almost_equal(X_vars, np.var(X, axis=1))
91+
X_means, X_vars = mean_variance_axis(X_csc, axis=1)
92+
assert_array_almost_equal(X_means, np.mean(X, axis=1))
93+
assert_array_almost_equal(X_vars, np.var(X, axis=1))
94+
assert_raises(TypeError, mean_variance_axis, X_lil, axis=1)
5095

5196

5297
def test_densify_rows():
