ENH Adds support for missing values in Random Forest #26391


Merged · 19 commits · Jul 27, 2023
6 changes: 6 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -92,6 +92,12 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |MajorFeature| :class:`ensemble.RandomForestClassifier` and
  :class:`ensemble.RandomForestRegressor` support missing values when
  the criterion is `gini`, `entropy`, or `log_loss` for classification,
  or `squared_error`, `friedman_mse`, or `poisson` for regression.
  :pr:`26391` by `Thomas Fan`_.

- |Feature| :class:`ensemble.RandomForestClassifier`,
:class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
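
As a quick illustration of the missing-value support announced in the changelog entry above, a minimal usage sketch (assuming scikit-learn >= 1.4, where this PR is available):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# NaNs in X are treated as missing values when the criterion supports them
# (the default "gini" does).
X = np.array([[0.0, 1.0], [np.nan, 2.0], [3.0, np.nan], [4.0, 5.0]])
y = np.array([0, 0, 1, 1])

clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
print(clf.predict(np.array([[np.nan, 1.0]])))
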
57 changes: 53 additions & 4 deletions sklearn/ensemble/_forest.py
Expand Up @@ -70,6 +70,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from ..tree._tree import DOUBLE, DTYPE
from ..utils import check_random_state, compute_sample_weight
from ..utils._param_validation import Interval, RealNotInt, StrOptions
from ..utils._tags import _safe_tags
from ..utils.multiclass import check_classification_targets, type_of_target
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
@@ -159,6 +160,7 @@ def _parallel_build_trees(
verbose=0,
class_weight=None,
n_samples_bootstrap=None,
missing_values_in_feature_mask=None,
):
"""
Private function used to fit a single tree in parallel."""
@@ -185,9 +187,21 @@
elif class_weight == "balanced_subsample":
curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
tree._fit(
X,
y,
sample_weight=curr_sample_weight,
check_input=False,
missing_values_in_feature_mask=missing_values_in_feature_mask,
)
else:
tree.fit(X, y, sample_weight=sample_weight, check_input=False)
tree._fit(
X,
y,
sample_weight=sample_weight,
check_input=False,
missing_values_in_feature_mask=missing_values_in_feature_mask,
)

return tree

@@ -345,9 +359,26 @@ def fit(self, X, y, sample_weight=None):
# Validate or convert input data
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")

X, y = self._validate_data(
X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
X,
y,
multi_output=True,
accept_sparse="csc",
dtype=DTYPE,
force_all_finite=False,
)
# _compute_missing_values_in_feature_mask checks if X has missing values and
# will raise an error if the underlying tree base estimator can't handle missing
# values. Only the criterion is required to determine if the tree supports
# missing values.
estimator = type(self.estimator)(criterion=self.criterion)
missing_values_in_feature_mask = (
estimator._compute_missing_values_in_feature_mask(
X, estimator_name=self.__class__.__name__
)
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

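For intuition, the per-feature mask computed here amounts to a column-wise NaN check for dense input (a simplified sketch, not the exact implementation):

import numpy as np

def missing_values_in_feature_mask(X):
    # True for every feature (column) containing at least one NaN;
    # None when X has no missing values at all.
    mask = np.isnan(X).any(axis=0)
    return mask if mask.any() else None
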
@@ -469,6 +500,7 @@ def fit(self, X, y, sample_weight=None):
verbose=self.verbose,
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
missing_values_in_feature_mask=missing_values_in_feature_mask,
)
for i, t in enumerate(trees)
)
@@ -596,7 +628,18 @@ def _validate_X_predict(self, X):
"""
Validate X whenever one tries to predict, apply, predict_proba."""
check_is_fitted(self)
X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
if self.estimators_[0]._support_missing_values(X):
force_all_finite = "allow-nan"
else:
force_all_finite = True
Review comment on lines +631 to +634 (Member):

I think we need to have a discussion about whether or not NaNs should be treated automatically as missing values. The issue I see here is that if this estimator sits in a pipeline and a bugged transformer earlier in the pipeline outputs NaNs, they will be silently treated as missing values by the random forest.

Maybe missing value support should be enabled through a parameter, and/or through the config context. Maybe we should add an after-fit check in our transformers to ensure they did not create NaNs.

Anyway, this is consistent with the current behavior of HGBT, so I'm fine merging it as is for now.
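
To illustrate the hazard (a hypothetical sketch, not code from this PR): a transformer that accidentally emits NaNs no longer makes the downstream forest raise at fit time.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

X, y = make_classification(random_state=0)

# A "bugged" transformer that accidentally turns some values into NaN.
bugged = FunctionTransformer(lambda X: np.where(X < -2, np.nan, X))

# With this change, the forest fits silently, treating the accidental NaNs
# as missing values instead of raising "Input contains NaN".
pipe = make_pipeline(bugged, RandomForestClassifier(random_state=0))
pipe.fit(X, y)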

Reply (Member):

👍 to your conclusion.

To continue your line of thought: how often do we think it happens that someone wants NaNs to mean missing values vs NaNs appearing because of a bug? Based on what we think is more likely, I think we should either make the handling automatic (with an opt-in to get a warning/exception) or make the warning/exception the default that needs to be opted out of.


X = self._validate_data(
X,
dtype=DTYPE,
accept_sparse="csr",
reset=False,
force_all_finite=force_all_finite,
)
if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
raise ValueError("No support for np.int64 index based sparse matrices")
return X
@@ -636,6 +679,12 @@ def feature_importances_(self):
all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
return all_importances / np.sum(all_importances)

def _more_tags(self):
# Only the criterion is required to determine if the tree supports
# missing values
estimator = type(self.estimator)(criterion=self.criterion)
return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
Review comment on lines +682 to +686 (Member):

Should we test the support or non-support of the criteria?
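
One way such a test might look (a sketch, relying on the `allow_nan` tag wiring in `_more_tags` above and the criteria listed in the changelog entry):

import pytest
from sklearn.ensemble import RandomForestRegressor
from sklearn.utils._tags import _safe_tags

@pytest.mark.parametrize(
    "criterion, allow_nan",
    [("squared_error", True), ("friedman_mse", True), ("absolute_error", False)],
)
def test_allow_nan_tag_matches_criterion(criterion, allow_nan):
    # The forest's allow_nan tag should mirror whether its tree criterion
    # supports missing values.
    est = RandomForestRegressor(criterion=criterion)
    assert _safe_tags(est, key="allow_nan") == allow_nan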



def _accumulate_prediction(predict, X, out, lock):
"""
88 changes: 88 additions & 0 deletions sklearn/ensemble/tests/test_forest.py
@@ -1809,3 +1809,91 @@ def test_round_samples_to_one_when_samples_too_low(class_weight):
n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
)
forest.fit(X, y)


@pytest.mark.parametrize(
"make_data, Forest",
[
(datasets.make_regression, RandomForestRegressor),
(datasets.make_classification, RandomForestClassifier),
],
)
def test_missing_values_is_resilient(make_data, Forest):
"""Check that forest can deal with missing values and have decent performance."""
Review comment (Member), suggested change:

-    """Check that forest can deal with missing values and have decent performance."""
+    """Check that forest can deal with missing values and has decent performance."""


rng = np.random.RandomState(0)
n_samples, n_features = 1000, 10
X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)

# Create dataset with missing values
X_missing = X.copy()
X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
Review comment (Member):

Add an assertion that X_missing indeed has np.nan in it.
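
Such an assertion might look like this (a one-line sketch, not part of the merged diff):

assert np.isnan(X_missing).any()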

X_missing_train, X_missing_test, y_train, y_test = train_test_split(
X_missing, y, random_state=0
)

# Train forest with missing values
forest_with_missing = Forest(random_state=rng, n_estimators=50)
forest_with_missing.fit(X_missing_train, y_train)
score_with_missing = forest_with_missing.score(X_missing_test, y_test)

# Train forest without missing values
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = Forest(random_state=rng, n_estimators=50)
forest.fit(X_train, y_train)
score_without_missing = forest.score(X_test, y_test)

# Score is still at least 80 percent of the score of the forest fit
# on data without missing values
assert score_with_missing >= 0.80 * score_without_missing


@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
def test_missing_value_is_predictive(Forest):
"""Check that the forest learns when missing values are only present for
a predictive feature."""
rng = np.random.RandomState(0)
n_samples = 300

X_non_predictive = rng.standard_normal(size=(n_samples, 10))
y = rng.randint(0, high=2, size=n_samples)

# Create a predictive feature using `y` and with some noise
X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
y_mask = y.astype(bool)
y_mask[X_random_mask] = ~y_mask[X_random_mask]

predictive_feature = rng.standard_normal(size=n_samples)
predictive_feature[y_mask] = np.nan

X_predictive = X_non_predictive.copy()
X_predictive[:, 5] = predictive_feature

(
X_predictive_train,
X_predictive_test,
X_non_predictive_train,
X_non_predictive_test,
y_train,
y_test,
) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
forest_predictive = Forest(random_state=0).fit(X_predictive_train, y_train)
forest_non_predictive = Forest(random_state=0).fit(X_non_predictive_train, y_train)

predictive_test_score = forest_predictive.score(X_predictive_test, y_test)

assert predictive_test_score >= 0.75
assert predictive_test_score >= forest_non_predictive.score(
X_non_predictive_test, y_test
)


def test_non_supported_criterion_raises_error_with_missing_values():
"""Raise error for unsupported criterion when there are missing values."""
X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
y = [0.5, 1.0]

forest = RandomForestRegressor(criterion="absolute_error")

msg = "RandomForestRegressor does not accept missing values"
with pytest.raises(ValueError, match=msg):
forest.fit(X, y)
8 changes: 6 additions & 2 deletions sklearn/tree/_classes.py
Expand Up @@ -189,7 +189,7 @@ def _support_missing_values(self, X):
and self.monotonic_cst is None
)

def _compute_missing_values_in_feature_mask(self, X):
def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
"""Return boolean mask denoting if there are missing values for each feature.

This method also ensures that X is finite.
@@ -199,13 +199,17 @@ def _compute_missing_values_in_feature_mask(self, X):
X : array-like of shape (n_samples, n_features), dtype=DOUBLE
Input data.

estimator_name : str or None, default=None
Name to use when raising an error. Defaults to the class name.

Returns
-------
missing_values_in_feature_mask : ndarray of shape (n_features,), or None
Missing value mask. If missing values are not supported or there
are no missing values, return None.
"""
common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")
estimator_name = estimator_name or self.__class__.__name__
common_kwargs = dict(estimator_name=estimator_name, input_name="X")

if not self._support_missing_values(X):
assert_all_finite(X, **common_kwargs)
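
The `estimator_name` override above lets the error raised inside the tree name the forest the user actually called, as exercised by the test earlier in this diff (a usage sketch):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
y = [0.5, 1.0]

# Raises a ValueError mentioning "RandomForestRegressor does not accept
# missing values", not the underlying DecisionTreeRegressor.
RandomForestRegressor(criterion="absolute_error").fit(X, y)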