From f3533a03b6c3d1abf61311bb2c3928610cac154d Mon Sep 17 00:00:00 2001 From: fabianegli Date: Mon, 19 Sep 2016 14:19:16 +0200 Subject: [PATCH 1/7] Throw an error with explicit message if n_estimators is not an integer. --- sklearn/ensemble/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/ensemble/base.py b/sklearn/ensemble/base.py index 9610f3402c8fb..b689c8f1467f1 100644 --- a/sklearn/ensemble/base.py +++ b/sklearn/ensemble/base.py @@ -6,6 +6,7 @@ # License: BSD 3 clause import numpy as np +import numbers from ..base import clone from ..base import BaseEstimator @@ -55,6 +56,10 @@ def __init__(self, base_estimator, n_estimators=10, def _validate_estimator(self, default=None): """Check the estimator and the n_estimator attribute, set the `base_estimator_` attribute.""" + if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): + raise ValueError("n_estimators must be an integer, " + "got {0}.".format(type(self.n_estimators))) + if self.n_estimators <= 0: raise ValueError("n_estimators must be greater than zero, " "got {0}.".format(self.n_estimators)) From 0a4bf16703cfd090bc5b384804133cd774b3c230 Mon Sep 17 00:00:00 2001 From: fabianegli Date: Mon, 19 Sep 2016 16:30:51 +0200 Subject: [PATCH 2/7] Testing for explicit message if n_estimators is not an integer. --- sklearn/ensemble/tests/test_base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index 0268715cde9ef..4915f2f6edc79 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -41,3 +41,13 @@ def test_base_zero_n_estimators(): assert_raise_message(ValueError, "n_estimators must be greater than zero, got 0.", ensemble.fit, iris.data, iris.target) + + +def test_base_not_int_n_estimators(): + # Check that instantiating a BaseEnsemble with a string as n_estimators raises + # a ValueError requesting n_estimators to be supplied as an integer. + ensemble_string = BaggingClassifier(base_estimator=Perceptron(), n_estimators='3') + iris = load_iris() + assert_raise_message(ValueError, + "n_estmators must be an integer", + ensemble_string.fit, iris.data, iris.target) From 24a03795f18198bd428e0ab09acca8f461514bcc Mon Sep 17 00:00:00 2001 From: fabianegli Date: Mon, 19 Sep 2016 17:12:23 +0200 Subject: [PATCH 3/7] Fixed typo in test for explicit message if n_estimators is an integer. --- sklearn/ensemble/tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index 4915f2f6edc79..cd6fba73f5b2b 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -49,5 +49,5 @@ def test_base_not_int_n_estimators(): ensemble_string = BaggingClassifier(base_estimator=Perceptron(), n_estimators='3') iris = load_iris() assert_raise_message(ValueError, - "n_estmators must be an integer", + "n_estimators must be an integer", ensemble_string.fit, iris.data, iris.target) From 652c161577dfc6bad08c295a77c2cd0c6334b6eb Mon Sep 17 00:00:00 2001 From: fabianegli Date: Tue, 20 Sep 2016 00:17:55 +0200 Subject: [PATCH 4/7] Added tests for np.int32 and float input. --- sklearn/ensemble/tests/test_base.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index cd6fba73f5b2b..b54cda9aa5745 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -32,6 +32,10 @@ def test_base(): assert_true(isinstance(ensemble[0], Perceptron)) + np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(), + n_estimators=np.int32(3)) + np_int_e.fit(iris.data, iris.target) + def test_base_zero_n_estimators(): # Check that instantiating a BaseEnsemble with n_estimators<=0 raises @@ -46,8 +50,12 @@ def test_base_zero_n_estimators(): def test_base_not_int_n_estimators(): # Check that instantiating a BaseEnsemble with a string as n_estimators raises # a ValueError requesting n_estimators to be supplied as an integer. - ensemble_string = BaggingClassifier(base_estimator=Perceptron(), n_estimators='3') + string_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators='3') iris = load_iris() assert_raise_message(ValueError, "n_estimators must be an integer", - ensemble_string.fit, iris.data, iris.target) + string_ensemble.fit, iris.data, iris.target) + float_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=3.0) + assert_raise_message(ValueError, + "n_estimators must be an integer", + float_ensemble.fit, iris.data, iris.target) From c73f7a6736051d6c72b4af71b9cf4a2873ed682c Mon Sep 17 00:00:00 2001 From: fabianegli Date: Tue, 20 Sep 2016 00:30:54 +0200 Subject: [PATCH 5/7] pep8 compliance --- sklearn/ensemble/tests/test_base.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index b54cda9aa5745..b1736fa9c5539 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -32,7 +32,7 @@ def test_base(): assert_true(isinstance(ensemble[0], Perceptron)) - np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(), + np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=np.int32(3)) np_int_e.fit(iris.data, iris.target) @@ -40,7 +40,8 @@ def test_base(): def test_base_zero_n_estimators(): # Check that instantiating a BaseEnsemble with n_estimators<=0 raises # a ValueError. - ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=0) + ensemble = BaggingClassifier(base_estimator=Perceptron(), + n_estimators=0) iris = load_iris() assert_raise_message(ValueError, "n_estimators must be greater than zero, got 0.", @@ -48,14 +49,16 @@ def test_base_zero_n_estimators(): def test_base_not_int_n_estimators(): - # Check that instantiating a BaseEnsemble with a string as n_estimators raises - # a ValueError requesting n_estimators to be supplied as an integer. - string_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators='3') + # Check that instantiating a BaseEnsemble with a string as n_estimators + # raises a ValueError demanding n_estimators to be supplied as an integer. + string_ensemble = BaggingClassifier(base_estimator=Perceptron(), + n_estimators='3') iris = load_iris() assert_raise_message(ValueError, "n_estimators must be an integer", string_ensemble.fit, iris.data, iris.target) - float_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=3.0) + float_ensemble = BaggingClassifier(base_estimator=Perceptron(), + n_estimators=3.0) assert_raise_message(ValueError, "n_estimators must be an integer", float_ensemble.fit, iris.data, iris.target) From 05c1dcd9a077bf6996497c074043f62dd817dfbe Mon Sep 17 00:00:00 2001 From: fabianegli Date: Tue, 20 Sep 2016 00:35:02 +0200 Subject: [PATCH 6/7] fix function name --- sklearn/ensemble/tests/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index b1736fa9c5539..f2e0366977204 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -34,7 +34,7 @@ def test_base(): np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=np.int32(3)) - np_int_e.fit(iris.data, iris.target) + np_int_ensemble.fit(iris.data, iris.target) def test_base_zero_n_estimators(): From 1f7c367c7f8e82a9708ac8547224dee3dea16b21 Mon Sep 17 00:00:00 2001 From: fabianegli Date: Tue, 20 Sep 2016 00:51:42 +0200 Subject: [PATCH 7/7] Import numpy to test n_estimators suplied as numpy int32. --- sklearn/ensemble/tests/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index f2e0366977204..8d2fb1048acb7 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -5,6 +5,7 @@ # Authors: Gilles Louppe # License: BSD 3 clause +import numpy as np from numpy.testing import assert_equal from nose.tools import assert_true