Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MAINT parameter validation in Perceptron #23521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
acb5004
Simplify comparison
Nwanna-Joseph May 31, 2022
c9b3ae9
Merge branch 'scikit-learn:main' into main
Nwanna-Joseph Jun 2, 2022
863fcaf
impl
Nwanna-Joseph Jun 2, 2022
c10120a
setup _parameter_constraints
Nwanna-Joseph Jun 2, 2022
f22ab2a
bug fixes
Nwanna-Joseph Jun 2, 2022
7b7670b
_parameter_constraints for BaseSGD
Nwanna-Joseph Jun 2, 2022
e60d2e4
Merge branch 'scikit-learn:main' into main
Nwanna-Joseph Jun 3, 2022
8525f12
fixes
Nwanna-Joseph Jun 4, 2022
5744e4f
fix validation
Nwanna-Joseph Jun 4, 2022
07722e5
black linting
Nwanna-Joseph Jun 4, 2022
efd4a16
Merge branch 'scikit-learn:main' into main
Nwanna-Joseph Jun 5, 2022
489f060
fixes
Nwanna-Joseph Jun 5, 2022
8e61778
black lint
Nwanna-Joseph Jun 5, 2022
07d46ec
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
Nwanna-Joseph Jun 5, 2022
2ac300b
merge
Nwanna-Joseph Jun 5, 2022
76fcdf6
bug fix
Nwanna-Joseph Jun 6, 2022
8fabf2f
fix
Nwanna-Joseph Jun 6, 2022
93ee1a8
fix perceptron
Nwanna-Joseph Jun 6, 2022
1c17f73
black lint
Nwanna-Joseph Jun 6, 2022
302a329
fix tests
Nwanna-Joseph Jun 6, 2022
b85ffa4
black lint
Nwanna-Joseph Jun 6, 2022
45ef9a2
optimize imports
Nwanna-Joseph Jun 6, 2022
419bcd7
test fix
Nwanna-Joseph Jun 6, 2022
e29f513
black lint
Nwanna-Joseph Jun 6, 2022
5fe1785
clean up
Nwanna-Joseph Jun 6, 2022
e4dc9c2
clean up
Nwanna-Joseph Jun 6, 2022
1b538d1
clean up
Nwanna-Joseph Jun 6, 2022
c14ff81
Merge branch 'scikit-learn:main' into validate_perceptron
Nwanna-Joseph Jun 8, 2022
e25c1e7
Merge branch 'scikit-learn:main' into validate_perceptron
Nwanna-Joseph Jun 14, 2022
d217e5b
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
Nwanna-Joseph Jun 14, 2022
c400b7f
Merge remote-tracking branch 'origin/validate_perceptron' into valida…
Nwanna-Joseph Jun 14, 2022
ed78625
build fix
Nwanna-Joseph Jun 14, 2022
68509d4
build fix
Nwanna-Joseph Jun 14, 2022
3918206
black linting
Nwanna-Joseph Jun 14, 2022
f1f9435
fix bug
Nwanna-Joseph Jun 14, 2022
8efc7bc
fix bug
Nwanna-Joseph Jun 14, 2022
e58dcd8
fix bug
Nwanna-Joseph Jun 14, 2022
b8b7335
fixes
jeremiedbb Jun 24, 2022
a97734e
Merge remote-tracking branch 'upstream/main' into pr/Nwanna-Joseph/23521
jeremiedbb Jun 24, 2022
3bd2cc6
try to leverage inheritence + some fixes
jeremiedbb Jun 24, 2022
cc7a30f
cln
jeremiedbb Jun 24, 2022
c5cb17f
lint
jeremiedbb Jun 24, 2022
3e146dc
hidden for undocumented learning_rates
jeremiedbb Jun 24, 2022
782fad9
address review comments
jeremiedbb Jun 24, 2022
c14410d
Apply suggestions from code review
glemaitre Jun 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions sklearn/linear_model/_passive_aggressive.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Authors: Rob Zinkov, Mathieu Blondel
# License: BSD 3 clause
from numbers import Real

from ._stochastic_gradient import BaseSGDClassifier
from ._stochastic_gradient import BaseSGDRegressor
from ._stochastic_gradient import DEFAULT_EPSILON
from ..utils._param_validation import Interval, StrOptions


class PassiveAggressiveClassifier(BaseSGDClassifier):
Expand Down Expand Up @@ -172,6 +174,12 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):
[1]
"""

# Declarative hyperparameter constraints for PassiveAggressiveClassifier.
# Starts from the inherited BaseSGDClassifier constraints, then narrows
# `loss` to the two passive-aggressive losses and constrains `C` to a
# strictly positive real (left-open interval via closed="right", i.e. C > 0).
_parameter_constraints = {
**BaseSGDClassifier._parameter_constraints,
"loss": [StrOptions({"hinge", "squared_hinge"})],
"C": [Interval(Real, 0, None, closed="right")],
}

def __init__(
self,
*,
Expand Down Expand Up @@ -236,19 +244,23 @@ def partial_fit(self, X, y, classes=None):
self : object
Fitted estimator.
"""
self._validate_params(for_partial_fit=True)
if self.class_weight == "balanced":
raise ValueError(
"class_weight 'balanced' is not supported for "
"partial_fit. For 'balanced' weights, use "
"`sklearn.utils.compute_class_weight` with "
"`class_weight='balanced'`. In place of y you "
"can use a large enough subset of the full "
"training set target to properly estimate the "
"class frequency distributions. Pass the "
"resulting weights as the class_weight "
"parameter."
)
if not hasattr(self, "classes_"):
self._validate_params()
self._more_validate_params(for_partial_fit=True)

if self.class_weight == "balanced":
raise ValueError(
"class_weight 'balanced' is not supported for "
"partial_fit. For 'balanced' weights, use "
"`sklearn.utils.compute_class_weight` with "
"`class_weight='balanced'`. In place of y you "
"can use a large enough subset of the full "
"training set target to properly estimate the "
"class frequency distributions. Pass the "
"resulting weights as the class_weight "
"parameter."
)

lr = "pa1" if self.loss == "hinge" else "pa2"
return self._partial_fit(
X,
Expand Down Expand Up @@ -287,6 +299,8 @@ def fit(self, X, y, coef_init=None, intercept_init=None):
Fitted estimator.
"""
self._validate_params()
self._more_validate_params()

lr = "pa1" if self.loss == "hinge" else "pa2"
return self._fit(
X,
Expand Down Expand Up @@ -445,6 +459,13 @@ class PassiveAggressiveRegressor(BaseSGDRegressor):
[-0.02306214]
"""

# Declarative hyperparameter constraints for PassiveAggressiveRegressor.
# Extends the inherited BaseSGDRegressor constraints: `loss` is restricted
# to the two PA regression losses, `C` must be strictly positive
# (left-open via closed="right"), and `epsilon` non-negative
# (closed="left" includes 0).
_parameter_constraints = {
**BaseSGDRegressor._parameter_constraints,
"loss": [StrOptions({"epsilon_insensitive", "squared_epsilon_insensitive"})],
"C": [Interval(Real, 0, None, closed="right")],
"epsilon": [Interval(Real, 0, None, closed="left")],
}

def __init__(
self,
*,
Expand Down Expand Up @@ -499,7 +520,10 @@ def partial_fit(self, X, y):
self : object
Fitted estimator.
"""
self._validate_params(for_partial_fit=True)
if not hasattr(self, "coef_"):
self._validate_params()
self._more_validate_params(for_partial_fit=True)

lr = "pa1" if self.loss == "epsilon_insensitive" else "pa2"
return self._partial_fit(
X,
Expand Down Expand Up @@ -537,6 +561,8 @@ def fit(self, X, y, coef_init=None, intercept_init=None):
Fitted estimator.
"""
self._validate_params()
self._more_validate_params()

lr = "pa1" if self.loss == "epsilon_insensitive" else "pa2"
return self._fit(
X,
Expand Down
16 changes: 15 additions & 1 deletion sklearn/linear_model/_perceptron.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Author: Mathieu Blondel
# License: BSD 3 clause
from numbers import Real

from ._stochastic_gradient import BaseSGDClassifier
from ..utils._param_validation import StrOptions, Interval


class Perceptron(BaseSGDClassifier):
Expand Down Expand Up @@ -37,7 +39,7 @@ class Perceptron(BaseSGDClassifier):

.. versionadded:: 0.19

tol : float, default=1e-3
tol : float or None, default=1e-3
The stopping criterion. If it is not None, the iterations will stop
when (loss > previous_loss - tol).

Expand Down Expand Up @@ -164,6 +166,18 @@ class Perceptron(BaseSGDClassifier):
0.939...
"""

# Declarative hyperparameter constraints for Perceptron.
# Copy the inherited BaseSGDClassifier constraints, then drop `loss` and
# `average` (Perceptron does not expose them as constructor parameters)
# and tighten the remaining numeric/string options:
#   - penalty: one of "l2"/"l1"/"elasticnet", or None for no penalty
#   - alpha and eta0: non-negative reals (closed="left" includes 0)
#   - l1_ratio: real in the closed interval [0, 1]
_parameter_constraints = {**BaseSGDClassifier._parameter_constraints}
_parameter_constraints.pop("loss")
_parameter_constraints.pop("average")
_parameter_constraints.update(
{
"penalty": [StrOptions({"l2", "l1", "elasticnet"}), None],
"alpha": [Interval(Real, 0, None, closed="left")],
"l1_ratio": [Interval(Real, 0, 1, closed="both")],
"eta0": [Interval(Real, 0, None, closed="left")],
}
)

def __init__(
self,
*,
Expand Down
Loading