diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 0897f331ebda0..13734bb828660 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -181,3 +181,10 @@ Cluster - Deprecate ``pooling_func`` unused parameter in :class:`cluster.AgglomerativeClustering`. :issue:`9875` by :user:`Kumar Ashutosh `. + +Changes to estimator checks +--------------------------- + +- Allow tests in :func:`utils.estimator_checks.check_estimator` to test estimators + that accept pairwise data. + :issue:`9701` by :user:`Kyle Johnson ` diff --git a/sklearn/base.py b/sklearn/base.py index 81c7e5dae7bcc..6f59cea3c7ab7 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -551,7 +551,6 @@ def is_classifier(estimator): def is_regressor(estimator): """Returns True if the given estimator is (probably) a regressor. - Parameters ---------- estimator : object diff --git a/sklearn/neighbors/regression.py b/sklearn/neighbors/regression.py index bd2ffb9b82489..b13f16cfd399e 100644 --- a/sklearn/neighbors/regression.py +++ b/sklearn/neighbors/regression.py @@ -9,6 +9,7 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np +from scipy.sparse import issparse from .base import _get_weights, _check_weights, NeighborsBase, KNeighborsMixin from .base import RadiusNeighborsMixin, SupervisedFloatMixin @@ -139,6 +140,11 @@ def predict(self, X): y : array of int, shape = [n_samples] or [n_samples, n_outputs] Target values """ + if issparse(X) and self.metric == 'precomputed': + raise ValueError( + "Sparse matrices not supported for prediction with " + "precomputed kernels. Densify your matrix." 
+ ) X = check_array(X, accept_sparse='csr') neigh_dist, neigh_ind = self.kneighbors(X) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 052c83c71d2e7..ceb53412018b8 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2,7 +2,7 @@ import numpy as np from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix, - dok_matrix, lil_matrix) + dok_matrix, lil_matrix, issparse) from sklearn import metrics from sklearn import neighbors, datasets @@ -731,10 +731,22 @@ def test_kneighbors_regressor_sparse(n_samples=40, knn = neighbors.KNeighborsRegressor(n_neighbors=n_neighbors, algorithm='auto') knn.fit(sparsemat(X), y) + + knn_pre = neighbors.KNeighborsRegressor(n_neighbors=n_neighbors, + metric='precomputed') + knn_pre.fit(pairwise_distances(X, metric='euclidean'), y) + for sparsev in SPARSE_OR_DENSE: X2 = sparsev(X) assert_true(np.mean(knn.predict(X2).round() == y) > 0.95) + X2_pre = sparsev(pairwise_distances(X, metric='euclidean')) + if issparse(sparsev(X2_pre)): + assert_raises(ValueError, knn_pre.predict, X2_pre) + else: + assert_true( + np.mean(knn_pre.predict(X2_pre).round() == y) > 0.95) + def test_neighbors_iris(): # Sanity checks on the iris dataset diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index fdbecc358be35..40fcb1fdd069f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -37,6 +37,7 @@ from sklearn.base import (clone, TransformerMixin, ClusterMixin, BaseEstimator, is_classifier, is_regressor) + from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score from sklearn.random_projection import BaseRandomProjection @@ -48,6 +49,8 @@ from sklearn.exceptions import DataConversionWarning from sklearn.exceptions import SkipTestWarning from sklearn.model_selection import train_test_split +from sklearn.metrics.pairwise import (rbf_kernel, linear_kernel, + 
+ pairwise_distances) from sklearn.utils import shuffle from sklearn.utils.fixes import signature @@ -355,10 +358,56 @@ def _is_32bit(): return struct.calcsize('P') * 8 == 32 +def _is_pairwise(estimator): + """Returns True if estimator has a _pairwise attribute set to True. + + Parameters + ---------- + estimator : object + Estimator object to test. + + Returns + ------- + out : bool + True if _pairwise is set to True and False otherwise. + """ + return bool(getattr(estimator, "_pairwise", False)) + + +def _is_pairwise_metric(estimator): + """Returns True if estimator has a metric attribute set to 'precomputed'. + + Parameters + ---------- + estimator : object + Estimator object to test. + + Returns + ------- + out : bool + True if metric is set to 'precomputed' and False otherwise. + """ + metric = getattr(estimator, "metric", None) + + return bool(metric == 'precomputed') + + +def pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel): + + if _is_pairwise_metric(estimator): + return pairwise_distances(X, metric='euclidean') + if _is_pairwise(estimator): + return kernel(X, X) + + return X + + def check_estimator_sparse_data(name, estimator_orig): + rng = np.random.RandomState(0) X = rng.rand(40, 10) X[X < .8] = 0 + X = pairwise_estimator_convert_X(X, estimator_orig) X_csr = sparse.csr_matrix(X) y = (4 * rng.rand(40)).astype(np.int) # catch deprecation warnings @@ -383,8 +432,8 @@ def check_estimator_sparse_data(name, estimator_orig): if hasattr(estimator, 'predict_proba'): probs = estimator.predict_proba(X) assert_equal(probs.shape, (X.shape[0], 4)) - except TypeError as e: - if 'sparse' not in repr(e): + except (TypeError, ValueError) as e: + if 'sparse' not in repr(e).lower(): print("Estimator %s doesn't seem to fail gracefully on " "sparse data: error message state explicitly that " "sparse input is not supported if this is not the case." 
@@ -405,7 +454,8 @@ def check_sample_weights_pandas_series(name, estimator_orig): if has_fit_parameter(estimator, "sample_weight"): try: import pandas as pd - X = pd.DataFrame([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]]) + X = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]]) + X = pd.DataFrame(pairwise_estimator_convert_X(X, estimator_orig)) y = pd.Series([1, 1, 1, 2, 2, 2]) weights = pd.Series([1] * 6) try: @@ -426,7 +476,8 @@ def check_sample_weights_list(name, estimator_orig): if has_fit_parameter(estimator_orig, "sample_weight"): estimator = clone(estimator_orig) rnd = np.random.RandomState(0) - X = rnd.uniform(size=(10, 3)) + X = pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), + estimator_orig) y = np.arange(10) % 3 y = multioutput_estimator_convert_y_2d(estimator, y) sample_weight = [3] * 10 @@ -438,7 +489,8 @@ def check_sample_weights_list(name, estimator_orig): def check_dtype_object(name, estimator_orig): # check that estimators treat dtype object as numeric if possible rng = np.random.RandomState(0) - X = rng.rand(40, 10).astype(object) + X = pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig) + X = X.astype(object) y = (X[:, 0] * 4).astype(np.int) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -485,6 +537,8 @@ def check_dict_unchanged(name, estimator_orig): else: X = 2 * rnd.uniform(size=(20, 3)) + X = pairwise_estimator_convert_X(X, estimator_orig) + y = X[:, 0].astype(np.int) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -522,6 +576,7 @@ def check_dont_overwrite_parameters(name, estimator_orig): estimator = clone(estimator_orig) rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20, 3)) + X = pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -568,6 +623,7 @@ def check_fit2d_predict1d(name, estimator_orig): # check by fitting a 2d array and 
predicting with a 1d array rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(20, 3)) + X = pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -621,6 +677,7 @@ def check_fit2d_1feature(name, estimator_orig): # informative message rnd = np.random.RandomState(0) X = 3 * rnd.uniform(size=(10, 1)) + X = pairwise_estimator_convert_X(X, estimator_orig) y = X[:, 0].astype(np.int) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -793,6 +850,7 @@ def check_pipeline_consistency(name, estimator_orig): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) X -= X.min() + X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) set_random_state(estimator) @@ -817,6 +875,7 @@ def check_fit_score_takes_y(name, estimator_orig): # in fit and score so they can be used in pipelines rnd = np.random.RandomState(0) X = rnd.uniform(size=(10, 3)) + X = pairwise_estimator_convert_X(X, estimator_orig) y = np.arange(10) % 3 estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -842,6 +901,7 @@ def check_fit_score_takes_y(name, estimator_orig): def check_estimators_dtypes(name, estimator_orig): rnd = np.random.RandomState(0) X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32) + X_train_32 = pairwise_estimator_convert_X(X_train_32, estimator_orig) X_train_64 = X_train_32.astype(np.float64) X_train_int_64 = X_train_32.astype(np.int64) X_train_int_32 = X_train_32.astype(np.int32) @@ -887,7 +947,8 @@ def check_estimators_empty_data_messages(name, estimator_orig): def check_estimators_nan_inf(name, estimator_orig): # Checks that Estimator X's do not contain NaN or inf. 
rnd = np.random.RandomState(0) - X_train_finite = rnd.uniform(size=(10, 3)) + X_train_finite = pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), + estimator_orig) X_train_nan = rnd.uniform(size=(10, 3)) X_train_nan[0, 0] = np.nan X_train_inf = rnd.uniform(size=(10, 3)) @@ -964,6 +1025,7 @@ def check_estimators_pickle(name, estimator_orig): # some estimators can't do features less than 0 X -= X.min() + X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) @@ -1138,6 +1200,7 @@ def check_classifiers_train(name, classifier_orig): classifier = clone(classifier_orig) if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']: X -= X.min() + X = pairwise_estimator_convert_X(X, classifier_orig) set_random_state(classifier) # raises error on malformed input for fit with assert_raises(ValueError, msg="The classifer {} does not" @@ -1159,11 +1222,18 @@ def check_classifiers_train(name, classifier_orig): assert_greater(accuracy_score(y, y_pred), 0.83) # raises error on malformed input for predict - with assert_raises(ValueError, msg="The classifier {} does not" - " raise an error when the number of features " - "in predict is different from the number of" - " features in fit.".format(name)): - classifier.predict(X.T) + if _is_pairwise(classifier): + with assert_raises(ValueError, msg="The classifier {} does not" + " raise an error when shape of X" + "in predict is not equal to (n_test_samples," + "n_training_samples)".format(name)): + classifier.predict(X.reshape(-1, 1)) + else: + with assert_raises(ValueError, msg="The classifier {} does not" + " raise an error when the number of features " + "in predict is different from the number of" + " features in fit.".format(name)): + classifier.predict(X.T) if hasattr(classifier, "decision_function"): try: # decision_function agrees with predict @@ -1179,12 +1249,21 @@ def check_classifiers_train(name, classifier_orig): assert_array_equal(np.argmax(decision, axis=1), y_pred) # 
raises error on malformed input for decision_function - with assert_raises(ValueError, msg="The classifier {} does" - " not raise an error when the number of " - "features in decision_function is " - "different from the number of features" - " in fit.".format(name)): - classifier.decision_function(X.T) + if _is_pairwise(classifier): + with assert_raises(ValueError, msg="The classifier {} does" + " not raise an error when the " + "shape of X in decision_function is " + "not equal to (n_test_samples, " + "n_training_samples) in fit." + .format(name)): + classifier.decision_function(X.reshape(-1, 1)) + else: + with assert_raises(ValueError, msg="The classifier {} does" + " not raise an error when the number " + "of features in decision_function is " + "different from the number of features" + " in fit.".format(name)): + classifier.decision_function(X.T) except NotImplementedError: pass if hasattr(classifier, "predict_proba"): @@ -1195,11 +1274,20 @@ def check_classifiers_train(name, classifier_orig): # check that probas for all classes sum to one assert_allclose(np.sum(y_prob, axis=1), np.ones(n_samples)) # raises error on malformed input for predict_proba - with assert_raises(ValueError, msg="The classifier {} does not" - " raise an error when the number of features " - "in predict_proba is different from the number " - "of features in fit.".format(name)): - classifier.predict_proba(X.T) + if _is_pairwise(classifier_orig): + with assert_raises(ValueError, msg="The classifier {} does not" + " raise an error when the shape of X" + "in predict_proba is not equal to " + "(n_test_samples, n_training_samples)." + .format(name)): + classifier.predict_proba(X.reshape(-1, 1)) + else: + with assert_raises(ValueError, msg="The classifier {} does not" + " raise an error when the number of " + "features in predict_proba is different " + "from the number of features in fit." 
+ .format(name)): + classifier.predict_proba(X.T) if hasattr(classifier, "predict_log_proba"): # predict_log_proba is a transformation of predict_proba y_log_prob = classifier.predict_log_proba(X) @@ -1213,6 +1301,7 @@ def check_estimators_fit_returns_self(name, estimator_orig): X, y = make_blobs(random_state=0, n_samples=9, n_features=4) # some want non-negative input X -= X.min() + X = pairwise_estimator_convert_X(X, estimator_orig) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -1260,7 +1349,7 @@ def check_supervised_y_2d(name, estimator_orig): # These only work on 2d, so this test makes no sense return rnd = np.random.RandomState(0) - X = rnd.uniform(size=(10, 3)) + X = pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), estimator_orig) y = np.arange(10) % 3 estimator = clone(estimator_orig) set_random_state(estimator) @@ -1294,6 +1383,7 @@ def check_classifiers_classes(name, classifier_orig): # We need to make sure that we have non negative data, for things # like NMF X -= X.min() - .1 + X = pairwise_estimator_convert_X(X, classifier_orig) y_names = np.array(["one", "two", "three"])[y] for y_names in [y_names, y_names.astype('O')]: @@ -1325,7 +1415,7 @@ def check_classifiers_classes(name, classifier_orig): @ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_regressors_int(name, regressor_orig): X, _ = _boston_subset() - X = X[:50] + X = pairwise_estimator_convert_X(X[:50], regressor_orig) rnd = np.random.RandomState(0) y = rnd.randint(3, size=X.shape[0]) y = multioutput_estimator_convert_y_2d(regressor_orig, y) @@ -1353,6 +1443,7 @@ def check_regressors_int(name, regressor_orig): @ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_regressors_train(name, regressor_orig): X, y = _boston_subset() + X = pairwise_estimator_convert_X(X, regressor_orig) y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled y = y.ravel() regressor = clone(regressor_orig) @@ 
-1429,6 +1520,12 @@ def check_class_weight_classifiers(name, classifier_orig): X, y = make_blobs(centers=n_centers, random_state=0, cluster_std=20) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=0) + + # can't use pairwise_estimator_convert_X() here: build gram matrix manually + if _is_pairwise(classifier_orig): + X_test = rbf_kernel(X_test, X_train) + X_train = rbf_kernel(X_train, X_train) + n_centers = len(np.unique(y_train)) if n_centers == 2: @@ -1512,6 +1609,7 @@ def check_estimators_overwrite_params(name, estimator_orig): X, y = make_blobs(random_state=0, n_samples=9) # some want non-negative input X -= X.min() + X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel) estimator = clone(estimator_orig) y = multioutput_estimator_convert_y_2d(estimator, y) @@ -1586,6 +1684,7 @@ def check_sparsify_coefficients(name, estimator_orig): @ignore_warnings(category=DeprecationWarning) def check_classifier_data_not_an_array(name, estimator_orig): X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1]]) + X = pairwise_estimator_convert_X(X, estimator_orig) y = [1, 1, 1, 2, 2, 2] y = multioutput_estimator_convert_y_2d(estimator_orig, y) check_estimators_data_not_an_array(name, estimator_orig, X, y) @@ -1594,6 +1693,7 @@ def check_classifier_data_not_an_array(name, estimator_orig): @ignore_warnings(category=DeprecationWarning) def check_regressor_data_not_an_array(name, estimator_orig): X, y = _boston_subset(n_samples=50) + X = pairwise_estimator_convert_X(X, estimator_orig) y = multioutput_estimator_convert_y_2d(estimator_orig, y) check_estimators_data_not_an_array(name, estimator_orig, X, y) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 1b3a1ea7e597a..2323f8a634eb2 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -18,6 +18,8 @@ from sklearn.cluster import MiniBatchKMeans from 
sklearn.decomposition import NMF from sklearn.linear_model import MultiTaskElasticNet +from sklearn.svm import SVC +from sklearn.neighbors import KNeighborsRegressor from sklearn.utils.validation import check_X_y, check_array @@ -251,3 +253,16 @@ def __init__(self): check_no_fit_attributes_set_in_init, 'estimator_name', NonConformantEstimator) + + +def test_check_estimator_pairwise(): + # check that check_estimator() works on estimator with _pairwise + # kernel or metric + + # test precomputed kernel + est = SVC(kernel='precomputed') + check_estimator(est) + + # test precomputed metric + est = KNeighborsRegressor(metric='precomputed') + check_estimator(est)