
Conversation

@Vincent-Maladiere (Contributor):

Reference Issues/PRs

Follow-up to #20424; closes #17979.
Thank you @nicoperetti, @nestornav, @jmloyola and @reshamas for the original PR!

What does this implement/fix? Explain your changes.

Enable multilabel-indicator targets for StackingClassifier by:

  • Removing the first column of each binary output array, because in the multilabel-indicator context the predict_proba columns are collinear for models like KNeighborsClassifier and RandomForestClassifier, contrary to MLPClassifier, which directly outputs an array of shape (n_samples, n_classes) (see the sketch after this list).
  • Defining classes_ during fit in the multilabel-indicator case.
  • Using a different LabelEncoder for each column of y during fit, then decoding the target in predict.
  • Adding a multilabel classification test, including KNeighborsClassifier, MLPClassifier and DummyClassifier as base estimators, with parametrized passthrough and stack_method.
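
A minimal sketch of the collinearity point above (not the PR's actual code): for a binary output, the two predict_proba columns sum to 1, so one of them can be dropped before stacking. The toy data below is illustrative only.

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy multilabel-indicator target: 3 binary outputs per sample.
X = np.arange(8, dtype=float).reshape(-1, 1)
Y = np.array([
    [0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 0],
    [0, 1, 0], [1, 0, 0], [1, 1, 1], [0, 0, 0],
])

clf = KNeighborsClassifier(n_neighbors=3).fit(X, Y)

# For a multilabel y, predict_proba returns one (n_samples, 2) array per
# output; within each array the columns are collinear (they sum to 1),
# so only the positive-class column is kept before stacking.
probas = clf.predict_proba(X)
stacked = np.hstack([p[:, 1:] for p in probas])
assert stacked.shape == (8, 3)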

Any other comments?

Should we go further in the tests by including more checks, like sparsity? @jjerphan, @glemaitre

@jjerphan (Member) left a comment:

Thank you for this follow-up, @Vincent-Maladiere.

Here are a few comments.

@glemaitre self-requested a review on August 22, 2022.
@glemaitre (Member) left a comment:

A couple of additional comments.

@glemaitre (Member):

I would keep this PR only on the multilabel case and postpone the multiclass multioutput for later.
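
For reference, a quick illustration of the distinction, using type_of_target (which the PR checks against); the toy arrays are illustrative only:

import numpy as np
from sklearn.utils.multiclass import type_of_target

y_multilabel = np.array([[0, 1], [1, 1]])   # binary indicator per output
y_multioutput = np.array([[0, 2], [1, 1]])  # an output with more than two classes
assert type_of_target(y_multilabel) == "multilabel-indicator"
assert type_of_target(y_multioutput) == "multiclass-multioutput"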

@glemaitre self-requested a review on September 5, 2022.
"stack_method", ["auto", "predict", "predict_proba", "decision_function"]
)
@pytest.mark.parametrize("passthrough", [False, True])
def test_stacking_classifier_multilabel(stack_method, passthrough):

I find that we test too many things at once (handled by the if-else).

I would instead split the test into several smaller tests.

@pytest.mark.parametrize(
    "estimator",
    [
        # output a 2D array of the probability of the positive class for each output
        MLPClassifier(random_state=42),
        # output a list of 2D array containing the probability of each class for each output
        RandomForestClassifier(random_state=42),
    ],
    ids=["MLPClassifier", "RandomForestClassifier"],
)
def test_stacking_classifier_multilabel_predict_proba(estimator):
    """Check the behaviour for the multilabel classification case and the `predict_proba`
    stacking method.

    Estimators are not consistent with the output arrays and we need to ensure that
    we handle all cases.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X_multilabel, y_multilabel, stratify=y_multilabel, random_state=42
    )
    n_outputs = 3

    # MLPClassifier will return a 2D array where each column is the probability
    # of the positive class for each output. We stack this array directly without
    # any further processing.
    estimators = [("est", estimator)]
    stacker = StackingClassifier(
        estimators=estimators,
        final_estimator=KNeighborsClassifier(),
        stack_method="predict_proba",
    ).fit(X_train, y_train)

    X_trans = stacker.transform(X_test)
    assert X_trans.shape == (X_test.shape[0], n_outputs)
    # we should not have any collinear classes and thus nothing should sum to 1
    assert not any(np.isclose(X_trans.sum(axis=1), 1.0))

    y_pred = stacker.predict(X_test)
    assert y_pred.shape == y_test.shape


def test_stacking_classifier_multilabel_decision_function():
    """Check the behaviour for the multilabel classification case and the
    `decision_function` stacking method. Only `RidgeClassifier` supports this
    case.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X_multilabel, y_multilabel, stratify=y_multilabel, random_state=42
    )
    n_outputs = 3

    estimators = [("est", RidgeClassifier())]
    stacker = StackingClassifier(
        estimators=estimators,
        final_estimator=KNeighborsClassifier(),
        stack_method="decision_function",
    ).fit(X_train, y_train)

    X_trans = stacker.transform(X_test)
    assert X_trans.shape == (X_test.shape[0], n_outputs)

    # check the shape consistency of the prediction
    y_pred = stacker.predict(X_test)
    assert y_pred.shape == y_test.shape


@pytest.mark.parametrize("stack_method", ["auto", "predict"])
@pytest.mark.parametrize("passthrough", [False, True])
def test_stacking_classifier_multilabel_auto_predict(stack_method, passthrough):
    """Check the behaviour for the multilabel classification case for stack methods
    supported for all estimators or automatically picked up.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X_multilabel, y_multilabel, stratify=y_multilabel, random_state=42
    )
    y_train_before_fit = y_train.copy()
    n_outputs = 3

    estimators = [
        ("mlp", MLPClassifier(random_state=42)),
        ("rf", RandomForestClassifier(random_state=42)),
        ("ridge", RidgeClassifier()),
    ]
    final_estimator = KNeighborsClassifier()

    clf = StackingClassifier(
        estimators=estimators,
        final_estimator=final_estimator,
        passthrough=passthrough,
        stack_method=stack_method,
    ).fit(X_train, y_train)

    # make sure we don't change `y_train` inplace
    assert_array_equal(y_train_before_fit, y_train)

    y_pred = clf.predict(X_test)
    assert y_pred.shape == y_test.shape

    if stack_method == "auto":
        expected_stack_methods = ["predict_proba", "predict_proba", "decision_function"]
    else:
        expected_stack_methods = ["predict"] * len(estimators)
    assert clf.stack_method_ == expected_stack_methods

    n_features_X_trans = n_outputs * len(estimators)
    if passthrough:
        n_features_X_trans += X_train.shape[1]
    X_trans = clf.transform(X_test)
    assert X_trans.shape == (X_test.shape[0], n_features_X_trans)

    assert_array_equal(clf.classes_, [np.array([0, 1])] * n_outputs)

"stack_method", ["auto", "predict", "predict_proba", "decision_function"]
)
@pytest.mark.parametrize("passthrough", [False, True])
def test_stacking_classifier_binary(stack_method, passthrough):

We already covered the binary case in some tests above.
We could drop this test.

@Vincent-Maladiere (Contributor, Author):

done, thanks for the review @glemaitre

@glemaitre (Member) left a comment:

Only a couple of outdated things that I previously proposed :).
LGTM otherwise.

it will drop one of the probability column when using probabilities
in the binary case. Indeed, the p(y|c=0) = 1 - p(y|c=1)
When `y` type is `"multilabel-indicator"` or `"multiclass-multioutput"`

we can probably remove the multiclass-multioutput mention for the moment.

in the binary case. Indeed, the p(y|c=0) = 1 - p(y|c=1)
When `y` type is `"multilabel-indicator"` or `"multiclass-multioutput"`
and the method used is `predict_proba`, `preds` can be either a ndarray

Suggested change
and the method used is `predict_proba`, `preds` can be either a ndarray
and the method used is `predict_proba`, `preds` can be either a `ndarray`

When `y` type is `"multilabel-indicator"` or `"multiclass-multioutput"`
and the method used is `predict_proba`, `preds` can be either a ndarray
of shape (n_samples, n_class) or for some estimators a list of ndarray.

Suggested change
of shape (n_samples, n_class) or for some estimators a list of ndarray.
of shape `(n_samples, n_class)` or for some estimators a list of `ndarray`.

self._le = LabelEncoder().fit(y)
self.classes_ = self._le.classes_
return super().fit(X, self._le.transform(y), sample_weight)
if type_of_target(y) in ("multilabel-indicator", "multiclass-multioutput"):

So we should restrict to multilabel-indicator here.

Comment on lines 680 to 682
# MLPClassifier will return a 2D array where each column is the probability
# of the positive class for each output. We stack this array directly without
# any further processing.

Actually we can remove this comment (it was written before I parametrized the test). The info is already in the parametrization.

X_trans = stacker.transform(X_test)
assert X_trans.shape == (X_test.shape[0], n_outputs)

# check the shape consistency of the prediction

Actually the assert is self-explanatory here.

Suggested change
# check the shape consistency of the prediction

@glemaitre added this to the 1.2 milestone on Sep 7, 2022.
@Vincent-Maladiere (Contributor, Author):

Something looks wrong with CircleCI

@glemaitre (Member):

You can ignore it ;)

@jjerphan (Member) commented on Sep 8, 2022:

Random timeouts/disconnections happen from time to time.

You can push an empty commit to re-trigger the CI if needed. To create an empty commit with this clear purpose, you can use:

git commit -m "Trigger CI" --allow-empty

@jjerphan (Member) left a comment:

LGTM. Thank you, @Vincent-Maladiere!

Before merging, I just have one question regarding a conversation which I think can be resolved.

PS: once again, we can ignore the unrelated error on CircleCI.


if isinstance(self._label_encoder, list):
# Handle the multilabel-indicator and multiclass-multioutput cases
y_pred = np.array([preds[:, 0] for preds in y_pred]).T

Can we resolve this conversation?

@Vincent-Maladiere (Contributor, Author):

Hey @jjerphan, I simply removed the multiclass-multioutput mention from the comment, so that we don't account for it at the moment.
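
For context, a minimal sketch of the per-column encode/decode round-trip discussed in this thread, assuming one LabelEncoder per output (it mirrors the idea, not the merged code):

import numpy as np
from sklearn.preprocessing import LabelEncoder

y = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1]])  # multilabel-indicator target

# One encoder per column of y, fitted during `fit`.
encoders = [LabelEncoder().fit(y[:, i]) for i in range(y.shape[1])]
y_enc = np.array([enc.transform(y[:, i]) for i, enc in enumerate(encoders)]).T

# In `predict`, each column is decoded back with its own encoder.
y_dec = np.array([enc.inverse_transform(y_enc[:, i]) for i, enc in enumerate(encoders)]).T
assert np.array_equal(y_dec, y)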

@jjerphan (Member):

OK, merging then.

The CircleCI failure is unrelated.

@jjerphan merged commit c18460f into scikit-learn:main on Sep 16, 2022.
Successfully merging this pull request may close these issues:

StackingClassifier to support multilabel classification