From 62695fd0abad7f6bc52ea6b0992a3f588a864726 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 17:18:20 +0200 Subject: [PATCH 1/3] Initial version of the faster implementation of scipy.stats.mode --- sklearn/neighbors/classification.py | 4 +-- sklearn/utils/extmath.py | 39 +++++++++++++++++++++++++++++ sklearn/utils/tests/test_extmath.py | 30 ++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index f0fd0b084365a..53a472086198b 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -10,7 +10,7 @@ import numpy as np from scipy import stats -from ..utils.extmath import weighted_mode +from ..utils.extmath import weighted_mode, _fast_mode from .base import \ _check_weights, _get_weights, \ @@ -180,7 +180,7 @@ def predict(self, X): y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - mode, _ = stats.mode(_y[neigh_ind, k], axis=1) + mode = _fast_mode(_y[neigh_ind, k], axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index fcb03b0cecddd..13b467aef2a8c 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -357,6 +357,45 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto', return U[:, :n_components], s[:n_components], V[:n_components, :] +def _fast_mode(x, axis=1): + """Returns a faster equivalent for scipy.mode + + This is only implemented for positive integer data. + + Parameters + ---------- + x : array_like, shape (n_samples, n_components) + n-dimensional array of which to find mode(s). + axis : int, optional + Axis along which to operate. Default is 1. + Only axis=1 is supported. + + Returns + ------- + mode: ndarray, shape=(n_samples) + index of the mode + + Examples + -------- + >>> x = np.array([[0, 1, 1], [2, 0, 2]]) + >>> _fast_mode(x, axis=1) + array([1, 2]) + """ + if not hasattr(x, "__array__") or x.dtype.kind != 'i' or x.ndim != 2: + raise ValueError('_fast_mode is only implemented for 2D integer ' + 'arrays!') + data = np.ones(x.shape, dtype=np.int).ravel() + indices = x.ravel() + indptr = np.arange(x.shape[0]+1)*x.shape[1] + # we use the fact that data for repeated indices is summed when + # creating sparse arrays. The index with highest value is then the mode + if axis != 1: + raise ValueError('Only axis=1 is supported.') + z = sparse.csr_matrix((data, indices, indptr), + shape=(x.shape[0], x.max() + 1)) + return np.asarray(np.argmax(z, axis=1)) + + def weighted_mode(a, w, axis=0): """Returns an array of the weighted modal (most common) value in a diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 2da6e5f5e9943..3ebe02c9b8bc4 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -31,6 +31,7 @@ from sklearn.utils.extmath import _deterministic_vector_sign_flip from sklearn.utils.extmath import softmax from sklearn.utils.extmath import stable_cumsum +from sklearn.utils.extmath import _fast_mode from sklearn.datasets.samples_generator import make_low_rank_matrix @@ -640,3 +641,32 @@ def test_stable_cumsum(): assert_array_equal(stable_cumsum(A, axis=0), np.cumsum(A, axis=0)) assert_array_equal(stable_cumsum(A, axis=1), np.cumsum(A, axis=1)) assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2)) + + +class TestFastMode(): + + def test_scipy_stats_axis_1(self): + rng = np.random.RandomState(0) + + X = rng.randint(10, size=(100, 20)) + mode_ref, _ = stats.mode(X, axis=1) + mode = _fast_mode(X, axis=1) + assert_array_equal(mode, mode_ref) + + @pytest.mark.parametrize( + 'x', + [np.ones((10, 10), dtype=np.float), 1, np.ones(5, dtype=np.int)], + ids=['array_float64', 'int', '1D-array']) + def test_input_validation(self, x): + with pytest.raises(ValueError, + match='only implemented for 2D integer arrays'): + _fast_mode(x) + + def test_ties(self): + # Check that ties are resolved in the same way as in stats.mode + X = np.ones((6, 9), dtype=np.int) + X[:, 3:] = 2 + X[:, 6:] = 3 + mode_ref, _ = stats.mode(X, axis=1) + mode = _fast_mode(X, axis=1) + assert_array_equal(mode, mode_ref) From 6bd2a08568d01f6853374f696365fd08ae7d700c Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 17:59:32 +0200 Subject: [PATCH 2/3] Also support weights --- sklearn/neighbors/classification.py | 6 +---- sklearn/utils/extmath.py | 37 +++++++++++++++++++++++++---- sklearn/utils/tests/test_extmath.py | 34 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index 53a472086198b..990b08e08bd1d 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -179,11 +179,7 @@ def predict(self, X): y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): - if weights is None: - mode = _fast_mode(_y[neigh_ind, k], axis=1) - else: - mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) - + mode = _fast_mode(_y[neigh_ind, k], weights=weights, axis=1) mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 13b467aef2a8c..184454c44382a 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -357,7 +357,7 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto', return U[:, :n_components], s[:n_components], V[:n_components, :] -def _fast_mode(x, axis=1): +def _fast_mode(x, weights=None, axis=1): """Returns a faster equivalent for scipy.mode This is only implemented for positive integer data. @@ -366,6 +366,8 @@ def _fast_mode(x, axis=1): ---------- x : array_like, shape (n_samples, n_components) n-dimensional array of which to find mode(s). + w : array_like, shape (n_samples, n_components) + n-dimensional array of weights for each value axis : int, optional Axis along which to operate. Default is 1. Only axis=1 is supported. @@ -379,13 +381,40 @@ def _fast_mode(x, axis=1): -------- >>> x = np.array([[0, 1, 1], [2, 0, 2]]) >>> _fast_mode(x, axis=1) - array([1, 2]) + array([[1], [2]]) + + Next we illustrate weighted mode calculations + + >>> x = np.array([[4, 1, 4, 2, 4, 2]]) + >>> weights = np.array([[1, 1, 1, 1, 1, 1]]) + >>> _fast_mode(x, weights) + array([[4]]) + + The value 4 appears three times: with uniform weights, the result is + simply the mode of the distribution. + + >>> weights = np.array([[1, 3, 0.5, 1.5, 1, 2]]) # deweight the 4's + >>> _fast_mode(x, weights) + array([[2]]) + + The value 2 has the highest score: it appears twice with weights of + 1.5 and 2: the sum of these is 3.5. + """ if not hasattr(x, "__array__") or x.dtype.kind != 'i' or x.ndim != 2: raise ValueError('_fast_mode is only implemented for 2D integer ' 'arrays!') - data = np.ones(x.shape, dtype=np.int).ravel() - indices = x.ravel() + if x.min() < 0: + raise ValueError('only positive data is supported.') + + if weights is None: + data = np.ones(x.shape, dtype=np.int).ravel() + else: + if x.shape != weights.shape: + raise ValueError("x.shape {} != weights.shape {}" + .format(x.shape, weights.shape)) + data = np.ascontiguousarray(weights).ravel() + indices = np.ascontiguousarray(x).ravel() indptr = np.arange(x.shape[0]+1)*x.shape[1] # we use the fact that data for repeated indices is summed when # creating sparse arrays. The index with highest value is then the mode diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 3ebe02c9b8bc4..90de0933ddcaf 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -662,6 +662,12 @@ def test_input_validation(self, x): match='only implemented for 2D integer arrays'): _fast_mode(x) + def test_negative_values(self): + x = - np.ones((10, 10), dtype=np.int) + with pytest.raises(ValueError, + match="only positive data is supported"): + _fast_mode(x) + def test_ties(self): # Check that ties are resolved in the same way as in stats.mode X = np.ones((6, 9), dtype=np.int) @@ -670,3 +676,31 @@ def test_ties(self): mode_ref, _ = stats.mode(X, axis=1) mode = _fast_mode(X, axis=1) assert_array_equal(mode, mode_ref) + + def test_uniform_weights(self): + # with uniform weights, results should be identical to + # stats.mode + rng = np.random.RandomState(0) + x = rng.randint(10, size=(10, 5)) + weights = np.ones(x.shape) + + mode, _ = stats.mode(x, axis=1) + mode2 = _fast_mode(x, weights, axis=1) + + assert_array_equal(mode, mode2) + + def test_random_weights(self): + # set this up so that each row should have a weighted mode of 6, + # with a score that is easily reproduced + mode_result = 6 + + rng = np.random.RandomState(0) + x = rng.randint(mode_result, size=(100, 10)) + w = rng.random_sample(x.shape) + + x[:, :5] = mode_result + w[:, :5] += 1 + + mode = _fast_mode(x, w, axis=1) + + assert_array_equal(mode, mode_result) From c7d3286f8dc6a1b3234bed2d972458e9e2d4a99a Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 18:17:05 +0200 Subject: [PATCH 3/3] Fixes to RadiusNeigboursClassifier --- sklearn/neighbors/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index 990b08e08bd1d..8f66ae57cba9a 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -9,8 +9,7 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np -from scipy import stats -from ..utils.extmath import weighted_mode, _fast_mode +from ..utils.extmath import _fast_mode from .base import \ _check_weights, _get_weights, \ @@ -416,11 +415,12 @@ def predict(self, X): pred_labels = np.zeros(len(neigh_ind), dtype=object) pred_labels[:] = [_y[ind, k] for ind in neigh_ind] if weights is None: - mode = np.array([stats.mode(pl)[0] + mode = np.array([_fast_mode(np.atleast_2d(pl)).ravel() for pl in pred_labels[inliers]], dtype=np.int) else: mode = np.array( - [weighted_mode(pl, w)[0] + [_fast_mode(np.atleast_2d(pl), np.atleast_2d(w), + axis=1).ravel() for (pl, w) in zip(pred_labels[inliers], weights[inliers]) ], dtype=np.int)