From b39630ee2cd063bbd297896db6253ecb212363b1 Mon Sep 17 00:00:00 2001 From: fbchow Date: Sat, 2 Nov 2019 14:56:50 -0700 Subject: [PATCH 1/4] Standardize sample weights validation in DummyClassifier (cherry picked from commit 95929b97d35099209c28653fab50802646ec41af) --- sklearn/dummy.py | 10 +++++++--- sklearn/tests/test_dummy.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 04322f0fc3bd1..54b8c3f1a6168 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -13,7 +13,8 @@ from .utils.validation import _num_samples from .utils.validation import check_array from .utils.validation import check_consistent_length -from .utils.validation import check_is_fitted, _check_sample_weight +from .utils.validation import check_is_fitted +from .utils.validation import _check_sample_weight from .utils.random import _random_choice_csc from .utils.stats import _weighted_percentile from .utils.multiclass import class_distribution @@ -156,7 +157,10 @@ def fit(self, X, y, sample_weight=None): self.n_outputs_ = y.shape[1] - check_consistent_length(X, y, sample_weight) + check_consistent_length(X, y) + + if sample_weight is not None: + sample_weight = _check_sample_weight(sample_weight, X) if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X) @@ -245,7 +249,7 @@ def predict(self, X): classes_ = [np.array([c]) for c in constant] y = _random_choice_csc(n_samples, classes_, class_prob, - self.random_state) + self.random_state) else: if self._strategy in ("most_frequent", "prior"): y = np.tile([classes_[k][class_prior_[k].argmax()] for diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index 55f3abc77b0de..a032e6f3b1ad6 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -571,6 +571,33 @@ def test_classification_sample_weight(): assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2]) +def test_sample_weight_invalid(): + # Check sample weighting raises errors. + X = [[0], [0], [1]] + y = [0, 1, 0] + sample_weight = [0.1, 1., 0.1] + + clf = DummyClassifier().fit(X, y, sample_weight) + assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2]) + + sample_weight = np.random.rand(3, 1) + with pytest.raises(ValueError): + clf.fit(X, y, sample_weight=sample_weight) + + sample_weight = np.array(0) + expected_err = r"Singleton.* cannot be considered a valid collection" + with pytest.raises(TypeError, match=expected_err): + clf.fit(X, y, sample_weight=sample_weight) + + sample_weight = np.ones(4) + with pytest.raises(ValueError): + clf.fit(X, y, sample_weight=sample_weight) + + sample_weight = np.ones(2) + with pytest.raises(ValueError): + clf.fit(X, y, sample_weight=sample_weight) + + def test_constant_strategy_sparse_target(): X = [[0]] * 5 # ignored y = sp.csc_matrix(np.array([[0, 1], From caa7871df7330a1f482cc702f81a37d2dd4b8e9b Mon Sep 17 00:00:00 2001 From: fbchow Date: Mon, 4 Nov 2019 21:35:36 -0800 Subject: [PATCH 2/4] Remove tests for each classifier Co-authored-by: Sallie Walecka (cherry picked from commit e6bced804b6c834e83b5cad32626cf33b705f7f9) --- sklearn/tests/test_dummy.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index a032e6f3b1ad6..911c91bf19468 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -580,23 +580,6 @@ def test_sample_weight_invalid(): clf = DummyClassifier().fit(X, y, sample_weight) assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2]) - sample_weight = np.random.rand(3, 1) - with pytest.raises(ValueError): - clf.fit(X, y, sample_weight=sample_weight) - - sample_weight = np.array(0) - expected_err = r"Singleton.* cannot be considered a valid collection" - with pytest.raises(TypeError, match=expected_err): - clf.fit(X, y, sample_weight=sample_weight) - - sample_weight = np.ones(4) - with pytest.raises(ValueError): - clf.fit(X, y, sample_weight=sample_weight) - - sample_weight = np.ones(2) - with pytest.raises(ValueError): - clf.fit(X, y, sample_weight=sample_weight) - def test_constant_strategy_sparse_target(): X = [[0]] * 5 # ignored From ee6ce0bd28b5864de4e713f24b6c9ede17fa39e6 Mon Sep 17 00:00:00 2001 From: Sallie Walecka Date: Tue, 12 Nov 2019 16:20:08 -0800 Subject: [PATCH 3/4] Remove duplicate test --- sklearn/tests/test_dummy.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index 911c91bf19468..55f3abc77b0de 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -571,16 +571,6 @@ def test_classification_sample_weight(): assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2]) -def test_sample_weight_invalid(): - # Check sample weighting raises errors. - X = [[0], [0], [1]] - y = [0, 1, 0] - sample_weight = [0.1, 1., 0.1] - - clf = DummyClassifier().fit(X, y, sample_weight) - assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2]) - - def test_constant_strategy_sparse_target(): X = [[0]] * 5 # ignored y = sp.csc_matrix(np.array([[0, 1], From 8509e85d69a10ed8e5519757bc6752992fb63cba Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 13:07:43 +0100 Subject: [PATCH 4/4] FIX Remove redundant checks --- sklearn/dummy.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 54b8c3f1a6168..9f99b0358b16e 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -13,8 +13,7 @@ from .utils.validation import _num_samples from .utils.validation import check_array from .utils.validation import check_consistent_length -from .utils.validation import check_is_fitted -from .utils.validation import _check_sample_weight +from .utils.validation import check_is_fitted, _check_sample_weight from .utils.random import _random_choice_csc from .utils.stats import _weighted_percentile from .utils.multiclass import class_distribution @@ -159,9 +158,6 @@ def fit(self, X, y, sample_weight=None): check_consistent_length(X, y) - if sample_weight is not None: - sample_weight = _check_sample_weight(sample_weight, X) - if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X)