Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1a33e56

Browse files
FIX Improves error message in partial_fit when early_stopping=True (#25694)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent cab7256 commit 1a33e56

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ Changelog
204204
no longer raise warnings when fitting data with feature names.
205205
:pr:`24873` by :user:`Tim Head <betatim>`.
206206

207+
- |Fix| Improves error message in :class:`neural_network.MLPClassifier` and
208+
:class:`neural_network.MLPRegressor`, when `early_stopping=True` and
209+
:meth:`partial_fit` is called. :pr:`25694` by `Thomas Fan`_.
210+
207211
:mod:`sklearn.preprocessing`
208212
............................
209213

sklearn/neural_network/_multilayer_perceptron.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,9 @@ def _fit_stochastic(
575575
)
576576

577577
# early_stopping in partial_fit doesn't make sense
578-
early_stopping = self.early_stopping and not incremental
578+
if self.early_stopping and incremental:
579+
raise ValueError("partial_fit does not support early_stopping=True")
580+
early_stopping = self.early_stopping
579581
if early_stopping:
580582
# don't stratify in multilabel classification
581583
should_stratify = is_classifier(self) and self.n_outputs_ == 1

sklearn/neural_network/tests/test_mlp.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,3 +953,16 @@ def test_mlp_warm_start_no_convergence(MLPEstimator, solver):
953953
with pytest.warns(ConvergenceWarning):
954954
model.fit(X_iris, y_iris)
955955
assert model.n_iter_ == 20
956+
957+
958+
@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
959+
def test_mlp_partial_fit_after_fit(MLPEstimator):
960+
"""Check partial fit does not fail after fit when early_stopping=True.
961+
962+
Non-regression test for gh-25693.
963+
"""
964+
mlp = MLPEstimator(early_stopping=True, random_state=0).fit(X_iris, y_iris)
965+
966+
msg = "partial_fit does not support early_stopping=True"
967+
with pytest.raises(ValueError, match=msg):
968+
mlp.partial_fit(X_iris, y_iris)

0 commit comments

Comments
 (0)