From c204d0925d9ce10b54797cab121f480632296759 Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Tue, 16 May 2023 17:07:23 -0400
Subject: [PATCH 01/14] ENH Adds support for missing values in random forest

---
 doc/whats_new/v1.3.rst                |  6 +++
 sklearn/ensemble/_forest.py           | 56 ++++++++++++++++++++++--
 sklearn/ensemble/tests/test_forest.py | 62 +++++++++++++++++++++++++++
 3 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 9cab0db995c5d..9f5b3b7531835 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -240,6 +240,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
+  :class:`ensemble.RandomForestRegressor` support missing values when
+  the criterion is `gini`, `entropy`, or `log_loss`,
+  for classification or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`xxxxx` by `Thomas Fan`_.
+
 - |Feature| :class:`ensemble.HistGradientBoostingRegressor` now supports
   the Gamma deviance loss via `loss="gamma"`.
   Using the Gamma deviance as loss function comes in handy for modelling skewed
diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index 16852970bf445..ec1b4124268cc 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -62,9 +62,11 @@ class calls the ``fit`` method of each sub-estimator on random samples
     ExtraTreeRegressor,
 )
 from ..tree._tree import DTYPE, DOUBLE
+from ..base import clone
 from ..utils import check_random_state, compute_sample_weight
 from ..exceptions import DataConversionWarning
 from ._base import BaseEnsemble, _partition_estimators
+from ..utils._tags import _safe_tags
 from ..utils.parallel import delayed, Parallel
 from ..utils.multiclass import check_classification_targets, type_of_target
 from ..utils.validation import (
@@ -156,6 +158,7 @@ def _parallel_build_trees(
     verbose=0,
     class_weight=None,
     n_samples_bootstrap=None,
+    feature_has_missing=None,
 ):
     """Private function used to fit a single tree in parallel."""
@@ -182,9 +185,21 @@ def _parallel_build_trees(
     elif class_weight == "balanced_subsample":
         curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)
 
-        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=curr_sample_weight,
+            check_input=False,
+            feature_has_missing=feature_has_missing,
+        )
     else:
-        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=sample_weight,
+            check_input=False,
+            feature_has_missing=feature_has_missing,
+        )
 
     return tree
@@ -343,9 +358,21 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
+
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=False,
         )
+        # _compute_feature_has_missing checks if X has missing values and will raise
+        # an error if the underlying tree base estimator can't handle missing values.
+        estimator = clone(self.estimator)
+        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
+        feature_has_missing = estimator._compute_feature_has_missing(X)
+
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
@@ -467,6 +494,7 @@ def fit(self, X, y, sample_weight=None):
                     verbose=self.verbose,
                     class_weight=self.class_weight,
                     n_samples_bootstrap=n_samples_bootstrap,
+                    feature_has_missing=feature_has_missing,
                 )
                 for i, t in enumerate(trees)
             )
@@ -592,7 +620,18 @@ def _validate_X_predict(self, X):
         """Validate X whenever one tries to predict, apply, predict_proba."""
         check_is_fitted(self)
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        if self.estimators_[0]._support_missing_values(X):
+            force_all_finite = "allow-nan"
+        else:
+            force_all_finite = True
+
+        X = self._validate_data(
+            X,
+            dtype=DTYPE,
+            accept_sparse="csr",
+            reset=False,
+            force_all_finite=force_all_finite,
+        )
         if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
             raise ValueError("No support for np.int64 index based sparse matrices")
         return X
@@ -632,6 +671,15 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)
 
+    def _more_tags(self):
+        # Ignore errors because the parameters are not validated
+        try:
+            estimator = clone(self.estimator)
+            estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
+            return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
+        except (AttributeError, TypeError):
+            return {}
+
 
 def _accumulate_prediction(predict, X, out, lock):
     """
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index 9bf0bb2becd9b..ca2f30fab6165 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -1791,3 +1791,65 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
         n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
     )
     forest.fit(X, y)
+
+
+@pytest.mark.parametrize(
+    "make_data, Forest",
+    [
+        (datasets.make_regression, RandomForestRegressor),
+        (datasets.make_classification, RandomForestClassifier),
+    ],
+)
+def test_missing_values_is_resilience(make_data, Forest):
+    """Check that forest can deal with missing values and have decent performance."""
+
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 1000, 10
+    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
+
+    # Create dataset with missing values
+    X_missing = X.copy()
+    X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
+    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
+        X_missing, y, random_state=0
+    )
+
+    # Train forest with missing values
+    forest_with_missing = Forest(random_state=rng, n_estimators=50)
+    forest_with_missing.fit(X_missing_train, y_train)
+    score_with_missing = forest_with_missing.score(X_missing_test, y_test)
+
+    # Train forest without missing values
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    forest = Forest(random_state=rng, n_estimators=50)
+    forest.fit(X_train, y_train)
+    score_without_missing = forest.score(X_test, y_test)
+
+    # Score is still 80 percent of the forest's score that had no missing values
+    assert score_with_missing >= 0.80 * score_without_missing
+
+
+@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
+def test_missing_value_is_predictive(Forest):
+    """Check the forest learns when only the missing value is predictive."""
+    rng = np.random.RandomState(0)
+    n_samples = 1000
+
+    X = rng.standard_normal(size=(n_samples, 10))
+    y = rng.randint(0, high=2, size=n_samples)
+
+    # Create a predictive feature using `y` and with some noise
+    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
+    y_mask = y.copy().astype(bool)
+    y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+    X_predictive = rng.standard_normal(size=n_samples)
+    X_predictive[y_mask] = np.nan
+
+    X[:, 5] = X_predictive
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
+    forest = Forest(random_state=rng, n_estimators=50).fit(X_train, y_train)
+
+    assert forest.score(X_train, y_train) >= 0.85
+    assert forest.score(X_test, y_test) >= 0.80
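
[Illustrative sketch, not part of the patch series: the user-facing behavior
PATCH 01 enables. It assumes a scikit-learn build with this series applied;
the arrays are made up for demonstration.]

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    # Dense float input may now contain NaN for the supported criteria
    # (here the default "gini").
    X = np.array([[0.0, 1.0], [np.nan, 2.0], [3.0, np.nan], [4.0, 5.0]])
    y = np.array([0, 0, 1, 1])

    clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
    print(clf.predict(X))  # predict also accepts NaN after fitting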

From 74a1a7964a89dfe4f4817715fe30d49cf28af21c Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Wed, 17 May 2023 10:52:59 -0400
Subject: [PATCH 02/14] DOC Adds PR number

---
 doc/whats_new/v1.3.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 9f5b3b7531835..124305256f8bc 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -244,7 +244,7 @@ Changelog
   :class:`ensemble.RandomForestRegressor` support missing values when
   the criterion is `gini`, `entropy`, or `log_loss`,
   for classification or `squared_error`, `friedman_mse`, or `poisson`
-  for regression. :pr:`xxxxx` by `Thomas Fan`_.
+  for regression. :pr:`26391` by `Thomas Fan`_.
 
 - |Feature| :class:`ensemble.HistGradientBoostingRegressor` now supports
   the Gamma deviance loss via `loss="gamma"`.
@@ -415,7 +415,7 @@ Changelog
 - |API| The `eps` parameter of the :func:`log_loss` has been deprecated and
   will be removed in 1.5.
   :pr:`25299` by :user:`Omar Salman `.
- 
+
 - |Feature| :func:`metrics.average_precision_score` now supports the
   multiclass case.
   :pr:`17388` by :user:`Geoffrey Bolmier ` and

From c9d6d3c4bbaac0ddaef6aabb222d10f3a879accc Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Thu, 18 May 2023 13:48:24 -0400
Subject: [PATCH 03/14] TST Fix

---
 sklearn/ensemble/tests/test_forest.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index ca2f30fab6165..6d3334d6ba210 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -1848,8 +1848,8 @@ def test_missing_value_is_predictive(Forest):
 
     X[:, 5] = X_predictive
 
-    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
-    forest = Forest(random_state=rng, n_estimators=50).fit(X_train, y_train)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    forest = Forest(random_state=0).fit(X_train, y_train)
 
     assert forest.score(X_train, y_train) >= 0.85
-    assert forest.score(X_test, y_test) >= 0.80
+    assert forest.score(X_test, y_test) >= 0.75
Fan" Date: Thu, 18 May 2023 13:52:29 -0400 Subject: [PATCH 04/14] TXT Formatting fix --- sklearn/ensemble/tests/test_forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 6d3334d6ba210..3ab555fa6566f 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1840,7 +1840,7 @@ def test_missing_value_is_predictive(Forest): # Create a predictive feature using `y` and with some noise X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05]) - y_mask = y.copy().astype(bool) + y_mask = y.astype(bool) y_mask[X_random_mask] = ~y_mask[X_random_mask] X_predictive = rng.standard_normal(size=n_samples) From 2ef4df4736c677d1f23b10ac8f457565d83000b0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 18 May 2023 15:37:41 -0400 Subject: [PATCH 05/14] TST Lower bound --- sklearn/ensemble/tests/test_forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 3ab555fa6566f..e9c6e845824b3 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1851,5 +1851,5 @@ def test_missing_value_is_predictive(Forest): X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) forest = Forest(random_state=0).fit(X_train, y_train) - assert forest.score(X_train, y_train) >= 0.85 + assert forest.score(X_train, y_train) >= 0.80 assert forest.score(X_test, y_test) >= 0.75 From 6c2b6e201bc06510a1381c747a3e0436b9059004 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 25 May 2023 13:43:20 -0400 Subject: [PATCH 06/14] Simplify logic --- sklearn/ensemble/_forest.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index ec1b4124268cc..5f1440d5a799a 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -62,7 +62,6 @@ class calls the ``fit`` method of each sub-estimator on random samples ExtraTreeRegressor, ) from ..tree._tree import DTYPE, DOUBLE -from ..base import clone from ..utils import check_random_state, compute_sample_weight from ..exceptions import DataConversionWarning from ._base import BaseEnsemble, _partition_estimators @@ -369,8 +368,7 @@ def fit(self, X, y, sample_weight=None): ) # _compute_feature_has_missing checks if X has missing values and will raise # an error if the underlying tree base estimator can't handle missing values. - estimator = clone(self.estimator) - estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params}) + estimator = type(self.estimator)(criterion=self.criterion) feature_has_missing = estimator._compute_feature_has_missing(X) if sample_weight is not None: @@ -672,12 +670,10 @@ def feature_importances_(self): return all_importances / np.sum(all_importances) def _more_tags(self): - # Ignore errors because the parameters are not validated - try: - estimator = clone(self.estimator) - estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params}) + if isinstance(self.estimator, BaseDecisionTree): + estimator = type(self.estimator)(criterion=self.criterion) return {"allow_nan": _safe_tags(estimator, key="allow_nan")} - except (AttributeError, TypeError): + else: return {} From 55ffce2e210b5c5d87d7947ff1e6aa0395997f62 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 25 May 2023 14:02:55 -0400 Subject: [PATCH 07/14] DOC Adds comment --- sklearn/ensemble/_forest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 5f1440d5a799a..6c13ebb902a7a 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -368,6 +368,8 @@ def fit(self, X, y, sample_weight=None): ) # _compute_feature_has_missing checks if X has missing values and will raise # an error if the underlying tree base estimator can't handle missing values. + # Only the criterion is required to determine if the tree supports missing + # values. estimator = type(self.estimator)(criterion=self.criterion) feature_has_missing = estimator._compute_feature_has_missing(X) @@ -671,6 +673,8 @@ def feature_importances_(self): def _more_tags(self): if isinstance(self.estimator, BaseDecisionTree): + # Only the criterion is required to determine if the tree supports + # missing values estimator = type(self.estimator)(criterion=self.criterion) return {"allow_nan": _safe_tags(estimator, key="allow_nan")} else: From 3caaf3add488b8a3a6540b0c5e6cea089286bf01 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 25 May 2023 14:08:54 -0400 Subject: [PATCH 08/14] Apply suggestions from code review Co-authored-by: Tim Head --- sklearn/ensemble/tests/test_forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index e9c6e845824b3..704373a3abee9 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1800,7 +1800,7 @@ def test_round_samples_to_one_when_samples_too_low(class_weight): (datasets.make_classification, RandomForestClassifier), ], ) -def test_missing_values_is_resilience(make_data, Forest): +def test_missing_values_is_resilient(make_data, Forest): """Check that forest can deal with missing values and have decent performance.""" rng = np.random.RandomState(0) From ed7a843201a5bf31a18d1fc29623a212f23267c9 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 26 May 2023 21:51:04 -0400 Subject: [PATCH 09/14] CLN Remove unneeded code --- sklearn/ensemble/_forest.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 6c13ebb902a7a..6a6d189821e3d 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -672,13 +672,10 @@ def feature_importances_(self): return all_importances / np.sum(all_importances) def _more_tags(self): - if isinstance(self.estimator, BaseDecisionTree): - # Only the criterion is required to determine if the tree supports - # missing values - estimator = type(self.estimator)(criterion=self.criterion) - return {"allow_nan": _safe_tags(estimator, key="allow_nan")} - else: - return {} + # Only the criterion is required to determine if the tree supports + # missing values + estimator = type(self.estimator)(criterion=self.criterion) + return {"allow_nan": _safe_tags(estimator, key="allow_nan")} def _accumulate_prediction(predict, X, out, lock): From a37a5747ceb38358c9b6f846ac0a74df8d0852ec Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 14 Jun 2023 16:57:12 +0100 Subject: [PATCH 10/14] TST Updates tests based on review --- sklearn/ensemble/_forest.py | 4 ++- sklearn/ensemble/tests/test_forest.py | 43 +++++++++++++++++++++------ sklearn/tree/_classes.py | 8 +++-- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 6a6d189821e3d..f96d0efc38f63 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -371,7 +371,9 @@ def fit(self, X, y, sample_weight=None): # Only the criterion is required to determine if the tree supports missing # values. estimator = type(self.estimator)(criterion=self.criterion) - feature_has_missing = estimator._compute_feature_has_missing(X) + feature_has_missing = estimator._compute_feature_has_missing( + X, estimator_name=self.__class__.__name__ + ) if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 704373a3abee9..c649941a42c08 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1833,9 +1833,9 @@ def test_missing_values_is_resilient(make_data, Forest): def test_missing_value_is_predictive(Forest): """Check the forest learns when only the missing value is predictive.""" rng = np.random.RandomState(0) - n_samples = 1000 + n_samples = 300 - X = rng.standard_normal(size=(n_samples, 10)) + X_non_predictive = rng.standard_normal(size=(n_samples, 10)) y = rng.randint(0, high=2, size=n_samples) # Create a predictive feature using `y` and with some noise @@ -1843,13 +1843,38 @@ def test_missing_value_is_predictive(Forest): y_mask = y.astype(bool) y_mask[X_random_mask] = ~y_mask[X_random_mask] - X_predictive = rng.standard_normal(size=n_samples) - X_predictive[y_mask] = np.nan + predictive_feature = rng.standard_normal(size=n_samples) + predictive_feature[y_mask] = np.nan - X[:, 5] = X_predictive + X_predictive = X_non_predictive.copy() + X_predictive[:, 5] = predictive_feature - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - forest = Forest(random_state=0).fit(X_train, y_train) + ( + X_predictive_train, + X_predictive_test, + X_non_predictive_train, + X_non_predictive_test, + y_train, + y_test, + ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0) + forest_predictive = Forest(random_state=0).fit(X_predictive_train, y_train) + forest_non_predictive = Forest(random_state=0).fit(X_non_predictive_train, y_train) + + predictive_test_score = forest_predictive.score(X_predictive_test, y_test) + + assert predictive_test_score >= 0.75 + assert predictive_test_score >= forest_non_predictive.score( + X_non_predictive_test, y_test + ) - assert forest.score(X_train, y_train) >= 0.80 - assert forest.score(X_test, y_test) >= 0.75 + +def test_non_supported_criterion_raises_error_with_missing_values(): + """Raise error for unsupported criterion when there are missing values.""" + X = np.array([[0, 1, 2], [np.nan, 0, 2.0]]) + y = [0.5, 1.0] + + forest = RandomForestRegressor(criterion="absolute_error") + + msg = "RandomForestRegressor does not accept missing values" + with pytest.raises(ValueError, match=msg): + forest.fit(X, y) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e4a3b0a9ee3af..ef239c4fa0b16 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -180,7 +180,7 @@ def get_n_leaves(self): def _support_missing_values(self, X): return not 

From b7d09ec416a391943b9d4cdbd320699599ef565b Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Thu, 15 Jun 2023 09:04:40 +0100
Subject: [PATCH 11/14] FIX Fixes merge issues

---
 sklearn/ensemble/_forest.py | 10 +++++-----
 sklearn/tree/_classes.py    |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index 63446c8178806..f4a398a735b62 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -366,12 +366,12 @@ def fit(self, X, y, sample_weight=None):
             dtype=DTYPE,
             force_all_finite=False,
         )
-        # _compute_feature_has_missing checks if X has missing values and will raise
-        # an error if the underlying tree base estimator can't handle missing values.
-        # Only the criterion is required to determine if the tree supports missing
-        # values.
+        # _compute_missing_values_in_feature_mask checks if X has missing values and
+        # will raise an error if the underlying tree base estimator can't handle missing
+        # values. Only the criterion is required to determine if the tree supports
+        # missing values.
         estimator = type(self.estimator)(criterion=self.criterion)
-        feature_has_missing = estimator._compute_feature_has_missing(
+        feature_has_missing = estimator._compute_missing_values_in_feature_mask(
             X, estimator_name=self.__class__.__name__
         )
diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py
index 756526caad63f..e4a34a321816d 100644
--- a/sklearn/tree/_classes.py
+++ b/sklearn/tree/_classes.py
@@ -181,7 +181,7 @@ def get_n_leaves(self):
     def _support_missing_values(self, X):
         return not issparse(X) and self._get_tags()["allow_nan"]
 
-    def _compute_missing_values_in_feature_mask(self, X):
+    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
         """Return boolean mask denoting if there are missing values for each feature.
 
         This method also ensures that X is finite.
Fan" Date: Thu, 20 Jul 2023 15:22:10 -0400 Subject: [PATCH 12/14] FIX Fixes errors --- sklearn/ensemble/_forest.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index f1162e6cad9c1..eecd13d403744 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -160,7 +160,7 @@ def _parallel_build_trees( verbose=0, class_weight=None, n_samples_bootstrap=None, - feature_has_missing=None, + missing_values_in_feature_mask=None, ): """ Private function used to fit a single tree in parallel.""" @@ -192,7 +192,7 @@ def _parallel_build_trees( y, sample_weight=curr_sample_weight, check_input=False, - feature_has_missing=feature_has_missing, + missing_values_in_feature_mask=missing_values_in_feature_mask, ) else: tree._fit( @@ -200,7 +200,7 @@ def _parallel_build_trees( y, sample_weight=sample_weight, check_input=False, - feature_has_missing=feature_has_missing, + missing_values_in_feature_mask=missing_values_in_feature_mask, ) return tree @@ -373,8 +373,10 @@ def fit(self, X, y, sample_weight=None): # values. Only the criterion is required to determine if the tree supports # missing values. estimator = type(self.estimator)(criterion=self.criterion) - feature_has_missing = estimator._compute_missing_values_in_feature_mask( - X, estimator_name=self.__class__.__name__ + missing_values_in_feature_mask = ( + estimator._compute_missing_values_in_feature_mask( + X, estimator_name=self.__class__.__name__ + ) ) if sample_weight is not None: @@ -498,7 +500,7 @@ def fit(self, X, y, sample_weight=None): verbose=self.verbose, class_weight=self.class_weight, n_samples_bootstrap=n_samples_bootstrap, - feature_has_missing=feature_has_missing, + missing_values_in_feature_mask=missing_values_in_feature_mask, ) for i, t in enumerate(trees) ) From ab1a4f4774b8717d753b1d059382d1b714539fc3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Jul 2023 20:20:18 -0400 Subject: [PATCH 13/14] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- sklearn/ensemble/tests/test_forest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index f89972a28fdb4..72111c9bb481c 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1849,7 +1849,8 @@ def test_missing_values_is_resilient(make_data, Forest): @pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor]) def test_missing_value_is_predictive(Forest): - """Check the forest learns when only the missing value is predictive.""" + """Check that the forest learns when missing values are only present for + a predictive feature.""" rng = np.random.RandomState(0) n_samples = 300 From 66dad75e21aed6e58823b32c3e087ecd34ce0773 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 21 Jul 2023 09:04:39 -0400 Subject: [PATCH 14/14] DOC Move to 1.4 --- doc/whats_new/v1.3.rst | 6 ------ doc/whats_new/v1.4.rst | 6 ++++++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 80206c39bf8d7..8d39ca2fed143 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -360,12 +360,6 @@ Changelog :mod:`sklearn.ensemble` ....................... 
-- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
-  :class:`ensemble.RandomForestRegressor` support missing values when
-  the criterion is `gini`, `entropy`, or `log_loss`,
-  for classification or `squared_error`, `friedman_mse`, or `poisson`
-  for regression. :pr:`26391` by `Thomas Fan`_.
-
 - |Feature| :class:`ensemble.HistGradientBoostingRegressor` now supports
   the Gamma deviance loss via `loss="gamma"`.
   Using the Gamma deviance as loss function comes in handy for modelling skewed
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index ff3be92064fe8..c55fb174b6b4f 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -85,6 +85,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
+  :class:`ensemble.RandomForestRegressor` support missing values when
+  the criterion is `gini`, `entropy`, or `log_loss`,
+  for classification or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`26391` by `Thomas Fan`_.
+
 - |Feature| :class:`ensemble.RandomForestClassifier`,
   :class:`ensemble.RandomForestRegressor`,
   :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`
   now support monotonic constraints,
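
[Illustrative sketch, not part of the patch series: a distilled version of
test_missing_value_is_predictive for trying the feature out. The data and the
expected score are made up, not taken from the test, and a build with the
series applied is assumed.]

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    y = rng.randint(0, 2, size=300)
    X = rng.standard_normal(size=(300, 3))
    X[y.astype(bool), 0] = np.nan  # NaN placement alone encodes the label

    clf = RandomForestClassifier(random_state=0).fit(X, y)
    print(clf.score(X, y))  # close to 1.0: the trees split on the NaN mask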