ENH Adds support for missing values in Random Forest #26391
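Once merged, the change can be exercised roughly as follows (an illustrative sketch, not part of the diff; it assumes a scikit-learn build that includes this PR and sticks to the default criterion):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(0)
X = rng.standard_normal(size=(200, 4))
y = (X[:, 0] > 0).astype(int)

# Inject some missing values: with this PR, np.nan in a dense X is treated
# as "missing" for supported criteria instead of failing input validation.
X[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan

clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)
print(clf.score(X, y))
```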
@@ -70,6 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree._tree import DOUBLE, DTYPE
 from ..utils import check_random_state, compute_sample_weight
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
+from ..utils._tags import _safe_tags
 from ..utils.multiclass import check_classification_targets, type_of_target
 from ..utils.parallel import Parallel, delayed
 from ..utils.validation import (
@@ -159,6 +160,7 @@ def _parallel_build_trees(
     verbose=0,
     class_weight=None,
     n_samples_bootstrap=None,
+    missing_values_in_feature_mask=None,
 ):
     """
     Private function used to fit a single tree in parallel."""
@@ -185,9 +187,21 @@ def _parallel_build_trees(
         elif class_weight == "balanced_subsample":
             curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)
 
-        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=curr_sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
     else:
-        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
+        tree._fit(
+            X,
+            y,
+            sample_weight=sample_weight,
+            check_input=False,
+            missing_values_in_feature_mask=missing_values_in_feature_mask,
+        )
 
     return tree
@@ -345,9 +359,26 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
 
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=False,
         )
+        # _compute_missing_values_in_feature_mask checks if X has missing values and
+        # will raise an error if the underlying tree base estimator can't handle missing
+        # values. Only the criterion is required to determine if the tree supports
+        # missing values.
+        estimator = type(self.estimator)(criterion=self.criterion)
+        missing_values_in_feature_mask = (
+            estimator._compute_missing_values_in_feature_mask(
+                X, estimator_name=self.__class__.__name__
+            )
+        )
 
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
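For intuition, the per-feature mask computed above reduces, for a dense float array that contains NaNs, to flagging which columns have at least one NaN (a simplified sketch; the real `_compute_missing_values_in_feature_mask` additionally checks that the criterion supports missing values and returns None when X has none):

```python
import numpy as np

X = np.array([
    [0.0, 1.0, 2.0],
    [np.nan, 0.0, 2.0],
    [3.0, 1.0, np.nan],
])

# One boolean per feature: True when that column contains at least one NaN.
missing_values_in_feature_mask = np.isnan(X).any(axis=0)
print(missing_values_in_feature_mask)  # [ True False  True]
```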
@@ -469,6 +500,7 @@ def fit(self, X, y, sample_weight=None):
                     verbose=self.verbose,
                     class_weight=self.class_weight,
                     n_samples_bootstrap=n_samples_bootstrap,
+                    missing_values_in_feature_mask=missing_values_in_feature_mask,
                 )
                 for i, t in enumerate(trees)
             )
@@ -596,7 +628,18 @@ def _validate_X_predict(self, X):
         """
         Validate X whenever one tries to predict, apply, predict_proba."""
         check_is_fitted(self)
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        if self.estimators_[0]._support_missing_values(X):
+            force_all_finite = "allow-nan"
+        else:
+            force_all_finite = True
+
+        X = self._validate_data(
+            X,
+            dtype=DTYPE,
+            accept_sparse="csr",
+            reset=False,
+            force_all_finite=force_all_finite,
+        )
         if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
             raise ValueError("No support for np.int64 index based sparse matrices")
         return X
@@ -636,6 +679,12 @@ def feature_importances_(self):
         all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
         return all_importances / np.sum(all_importances)
 
+    def _more_tags(self):
+        # Only the criterion is required to determine if the tree supports
+        # missing values
+        estimator = type(self.estimator)(criterion=self.criterion)
+        return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
Review comment on lines +682 to +686: Should we test the support or non-support of the criteria?
 
 
 def _accumulate_prediction(predict, X, out, lock):
     """
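One way to address the review question above would be a parametrized test tying the forest's `allow_nan` tag to the chosen criterion. This is only a sketch, not part of the PR; it leans on the private `_safe_tags` helper that the diff itself imports, and the exact criterion/tag pairings are assumptions based on the tree estimators' current missing-value support:

```python
import pytest

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.utils._tags import _safe_tags


@pytest.mark.parametrize(
    "Forest, criterion, allow_nan",
    [
        (RandomForestClassifier, "gini", True),
        (RandomForestRegressor, "squared_error", True),
        (RandomForestRegressor, "absolute_error", False),
    ],
)
def test_forest_allow_nan_tag_matches_criterion(Forest, criterion, allow_nan):
    # The forest should advertise NaN support exactly when the underlying
    # tree's criterion supports missing values.
    assert _safe_tags(Forest(criterion=criterion), key="allow_nan") == allow_nan
```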
@@ -1809,3 +1809,91 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
        n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
    )
    forest.fit(X, y)


@pytest.mark.parametrize(
    "make_data, Forest",
    [
        (datasets.make_regression, RandomForestRegressor),
        (datasets.make_classification, RandomForestClassifier),
    ],
)
def test_missing_values_is_resilient(make_data, Forest):
    """Check that forest can deal with missing values and have decent performance."""
    rng = np.random.RandomState(0)
    n_samples, n_features = 1000, 10
    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)

    # Create dataset with missing values
    X_missing = X.copy()
    X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
Review comment: Add an assertion that …
    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
        X_missing, y, random_state=0
    )

    # Train forest with missing values
    forest_with_missing = Forest(random_state=rng, n_estimators=50)
    forest_with_missing.fit(X_missing_train, y_train)
    score_with_missing = forest_with_missing.score(X_missing_test, y_test)

    # Train forest without missing values
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    forest = Forest(random_state=rng, n_estimators=50)
    forest.fit(X_train, y_train)
    score_without_missing = forest.score(X_test, y_test)

    # Score is still 80 percent of the forest's score that had no missing values
    assert score_with_missing >= 0.80 * score_without_missing


@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
def test_missing_value_is_predictive(Forest):
    """Check that the forest learns when missing values are only present for
    a predictive feature."""
    rng = np.random.RandomState(0)
    n_samples = 300

    X_non_predictive = rng.standard_normal(size=(n_samples, 10))
    y = rng.randint(0, high=2, size=n_samples)

    # Create a predictive feature using `y` and with some noise
    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
    y_mask = y.astype(bool)
    y_mask[X_random_mask] = ~y_mask[X_random_mask]

    predictive_feature = rng.standard_normal(size=n_samples)
    predictive_feature[y_mask] = np.nan

    X_predictive = X_non_predictive.copy()
    X_predictive[:, 5] = predictive_feature

    (
        X_predictive_train,
        X_predictive_test,
        X_non_predictive_train,
        X_non_predictive_test,
        y_train,
        y_test,
    ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
    forest_predictive = Forest(random_state=0).fit(X_predictive_train, y_train)
    forest_non_predictive = Forest(random_state=0).fit(X_non_predictive_train, y_train)

    predictive_test_score = forest_predictive.score(X_predictive_test, y_test)

    assert predictive_test_score >= 0.75
    assert predictive_test_score >= forest_non_predictive.score(
        X_non_predictive_test, y_test
    )


def test_non_supported_criterion_raises_error_with_missing_values():
    """Raise error for unsupported criterion when there are missing values."""
    X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
    y = [0.5, 1.0]

    forest = RandomForestRegressor(criterion="absolute_error")

    msg = "RandomForestRegressor does not accept missing values"
    with pytest.raises(ValueError, match=msg):
        forest.fit(X, y)
Review comment: I think we need to have a discussion about whether or not NaNs should be treated automatically as missing values. The issue I see is that if this is in a pipeline and a bugged transformer before it outputs NaNs, they will be silently ignored by the random forest.
Maybe missing-value support should be enabled through a parameter, and/or through the config context. Maybe we should add an after-fit check in our transformers to ensure they did not create NaNs.
Anyway, this is consistent with the current behavior of HGBT, so I'm fine merging it as is for now.
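To make the concern above concrete, here is a hypothetical pipeline in which an upstream transformer leaks NaNs by accident (purely illustrative; `BuggyTransformer` is invented for this example and is not part of scikit-learn):

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline


class BuggyTransformer(TransformerMixin, BaseEstimator):
    """Stand-in for a transformer that accidentally emits NaNs."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = np.asarray(X, dtype=float).copy()
        X[0, 0] = np.nan  # the "bug": a NaN slips into the output
        return X


X = np.random.RandomState(0).standard_normal(size=(50, 3))
y = (X[:, 0] > 0).astype(int)

# With NaNs treated automatically as missing values, this fits without any
# warning: the buggy NaN is silently consumed by the forest.
pipe = make_pipeline(BuggyTransformer(), RandomForestClassifier(random_state=0))
pipe.fit(X, y)
```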
Review comment: 👍 to your conclusion.
To continue your line of thought: how often do we think it happens that someone wants NaNs to mean missing values versus NaNs appearing because of a bug? Based on what we think is more likely, I think we should either make the handling automatic (with an opt-in to get a warning/exception) or make the warning/exception the default that needs opting out of.