Commit 4094851

ENH Adds support for missing values in Random Forest (scikit-learn#26391)

Authored by: thomasjpfan (with betatim and jjerphan)
Co-authored-by: Tim Head <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>

1 parent b8d4f46 · commit 4094851

File tree: 4 files changed, +153 −6 lines

  doc/whats_new/v1.4.rst
  sklearn/ensemble/_forest.py
  sklearn/ensemble/tests/test_forest.py
  sklearn/tree/_classes.py

doc/whats_new/v1.4.rst

Lines changed: 6 additions & 0 deletions

@@ -92,6 +92,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
+  :class:`ensemble.RandomForestRegressor` support missing values when
+  the criterion is `gini`, `entropy`, or `log_loss`,
+  for classification or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`26391` by `Thomas Fan`_.
+
 - |Feature| :class:`ensemble.RandomForestClassifier`,
   :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
   and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
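
For context, here is a minimal usage sketch of the behavior this changelog entry describes (illustrative toy data, assuming scikit-learn >= 1.4):

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    # NaN entries are accepted because "gini" (the default criterion) supports them.
    X = np.array([[0.0, 1.0], [np.nan, 2.0], [3.0, np.nan], [4.0, 5.0]])
    y = np.array([0, 0, 1, 1])

    clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
    print(clf.predict(X))  # NaN is also accepted at predict time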

sklearn/ensemble/_forest.py

Lines changed: 53 additions & 4 deletions
@@ -70,6 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree._tree import DOUBLE, DTYPE
 from ..utils import check_random_state, compute_sample_weight
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
+from ..utils._tags import _safe_tags
 from ..utils.multiclass import check_classification_targets, type_of_target
 from ..utils.parallel import Parallel, delayed
 from ..utils.validation import (
@@ -159,6 +160,7 @@ def _parallel_build_trees(
     verbose=0,
     class_weight=None,
     n_samples_bootstrap=None,
+    missing_values_in_feature_mask=None,
 ):
     """
     Private function used to fit a single tree in parallel."""
@@ -185,9 +187,21 @@ def _parallel_build_trees(
         elif class_weight == "balanced_subsample":
             curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)
 
-        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=curr_sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
     else:
-        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
 
     return tree
 

@@ -345,9 +359,26 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
+
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=False,
+        )
+        # _compute_missing_values_in_feature_mask checks if X has missing values and
+        # will raise an error if the underlying tree base estimator can't handle missing
+        # values. Only the criterion is required to determine if the tree supports
+        # missing values.
+        estimator = type(self.estimator)(criterion=self.criterion)
+        missing_values_in_feature_mask = (
+            estimator._compute_missing_values_in_feature_mask(
+                X, estimator_name=self.__class__.__name__
+            )
         )
+
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
 
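Conceptually, the mask computed above marks which feature columns contain at least one NaN, or is None when X is fully finite. A simplified sketch of that idea (not the actual implementation, which also validates finiteness and criterion support):

    import numpy as np

    def missing_mask_sketch(X):
        # One boolean per feature: does this column contain any NaN?
        mask = np.isnan(X).any(axis=0)
        # None signals "no missing values anywhere", so trees can skip NaN handling.
        return mask if mask.any() else None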

@@ -469,6 +500,7 @@ def fit(self, X, y, sample_weight=None):
                     verbose=self.verbose,
                     class_weight=self.class_weight,
                     n_samples_bootstrap=n_samples_bootstrap,
+                    missing_values_in_feature_mask=missing_values_in_feature_mask,
                 )
                 for i, t in enumerate(trees)
             )
@@ -596,7 +628,18 @@ def _validate_X_predict(self, X):
         """
         Validate X whenever one tries to predict, apply, predict_proba."""
         check_is_fitted(self)
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        if self.estimators_[0]._support_missing_values(X):
+            force_all_finite = "allow-nan"
+        else:
+            force_all_finite = True
+
+        X = self._validate_data(
+            X,
+            dtype=DTYPE,
+            accept_sparse="csr",
+            reset=False,
+            force_all_finite=force_all_finite,
+        )
         if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
             raise ValueError("No support for np.int64 index based sparse matrices")
         return X
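
The effect of the relaxed validation, sketched as a hypothetical usage example (toy data; assumes a criterion that supports missing values):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor

    X = np.array([[0.0], [1.0], [2.0], [3.0]])
    y = np.array([0.0, 1.0, 2.0, 3.0])
    reg = RandomForestRegressor(n_estimators=5, random_state=0).fit(X, y)

    # The fitted trees support missing values, so validation switches to
    # "allow-nan" and predict accepts NaN instead of raising.
    print(reg.predict(np.array([[np.nan]])))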
@@ -636,6 +679,12 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)
 
+    def _more_tags(self):
+        # Only the criterion is required to determine if the tree supports
+        # missing values
+        estimator = type(self.estimator)(criterion=self.criterion)
+        return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
+
 
 def _accumulate_prediction(predict, X, out, lock):
     """

sklearn/ensemble/tests/test_forest.py

Lines changed: 88 additions & 0 deletions
@@ -1809,3 +1809,91 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
         n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
     )
     forest.fit(X, y)
+
+
+@pytest.mark.parametrize(
+    "make_data, Forest",
+    [
+        (datasets.make_regression, RandomForestRegressor),
+        (datasets.make_classification, RandomForestClassifier),
+    ],
+)
+def test_missing_values_is_resilient(make_data, Forest):
+    """Check that forest can deal with missing values and have decent performance."""
+
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 1000, 10
+    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
+
+    # Create dataset with missing values
+    X_missing = X.copy()
+    X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
+    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
+        X_missing, y, random_state=0
+    )
+
+    # Train forest with missing values
+    forest_with_missing = Forest(random_state=rng, n_estimators=50)
+    forest_with_missing.fit(X_missing_train, y_train)
+    score_with_missing = forest_with_missing.score(X_missing_test, y_test)
+
+    # Train forest without missing values
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    forest = Forest(random_state=rng, n_estimators=50)
+    forest.fit(X_train, y_train)
+    score_without_missing = forest.score(X_test, y_test)
+
+    # Score is still 80 percent of the forest's score that had no missing values
+    assert score_with_missing >= 0.80 * score_without_missing
+
+
+@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
+def test_missing_value_is_predictive(Forest):
+    """Check that the forest learns when missing values are only present for
+    a predictive feature."""
+    rng = np.random.RandomState(0)
+    n_samples = 300
+
+    X_non_predictive = rng.standard_normal(size=(n_samples, 10))
+    y = rng.randint(0, high=2, size=n_samples)
+
+    # Create a predictive feature using `y` and with some noise
+    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
+    y_mask = y.astype(bool)
+    y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+    predictive_feature = rng.standard_normal(size=n_samples)
+    predictive_feature[y_mask] = np.nan
+
+    X_predictive = X_non_predictive.copy()
+    X_predictive[:, 5] = predictive_feature
+
+    (
+        X_predictive_train,
+        X_predictive_test,
+        X_non_predictive_train,
+        X_non_predictive_test,
+        y_train,
+        y_test,
+    ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
+    forest_predictive = Forest(random_state=0).fit(X_predictive_train, y_train)
+    forest_non_predictive = Forest(random_state=0).fit(X_non_predictive_train, y_train)
+
+    predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
+
+    assert predictive_test_score >= 0.75
+    assert predictive_test_score >= forest_non_predictive.score(
+        X_non_predictive_test, y_test
+    )
+
+
+def test_non_supported_criterion_raises_error_with_missing_values():
+    """Raise error for unsupported criterion when there are missing values."""
+    X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
+    y = [0.5, 1.0]
+
+    forest = RandomForestRegressor(criterion="absolute_error")
+
+    msg = "RandomForestRegressor does not accept missing values"
+    with pytest.raises(ValueError, match=msg):
+        forest.fit(X, y)

sklearn/tree/_classes.py

Lines changed: 6 additions & 2 deletions
@@ -189,7 +189,7 @@ def _support_missing_values(self, X):
             and self.monotonic_cst is None
         )
 
-    def _compute_missing_values_in_feature_mask(self, X):
+    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
         """Return boolean mask denoting if there are missing values for each feature.
 
         This method also ensures that X is finite.
@@ -199,13 +199,17 @@ def _compute_missing_values_in_feature_mask(self, X):
         X : array-like of shape (n_samples, n_features), dtype=DOUBLE
             Input data.
 
+        estimator_name : str or None, default=None
+            Name to use when raising an error. Defaults to the class name.
+
         Returns
         -------
         missing_values_in_feature_mask : ndarray of shape (n_features,), or None
             Missing value mask. If missing values are not supported or there
             are no missing values, return None.
         """
-        common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")
+        estimator_name = estimator_name or self.__class__.__name__
+        common_kwargs = dict(estimator_name=estimator_name, input_name="X")
 
         if not self._support_missing_values(X):
             assert_all_finite(X, **common_kwargs)
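
Why the new estimator_name parameter matters: the forest delegates this check to a throwaway tree instance, so without the override the ValueError would name the tree class rather than the forest the user constructed. A hypothetical, simplified sketch of the fallback pattern (not sklearn code; the error wording matches the new test above):

    class _TreeSketch:
        def compute_mask(self, X, estimator_name=None):
            # A caller (e.g. the forest) passes its public class name so the
            # error blames the estimator the user actually constructed.
            estimator_name = estimator_name or type(self).__name__
            raise ValueError(f"{estimator_name} does not accept missing values")

    try:
        _TreeSketch().compute_mask(None, estimator_name="RandomForestRegressor")
    except ValueError as e:
        print(e)  # RandomForestRegressor does not accept missing values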
