[MRG] More deprecations for 0.23 #15860


Merged
merged 34 commits on Jan 13, 2020

Changes from all commits (34 commits)
5dd561a
removed warn_on_dtype
NicolasHug Dec 5, 2019
e05e17a
removed parameters to check_is_fitted
NicolasHug Dec 5, 2019
cdfac1e
all_estimators parameters
NicolasHug Dec 5, 2019
ef5d570
deprecated n_components attribute in AgglomerativeClustering
NicolasHug Dec 5, 2019
5485edb
Merge branch 'master' of github.com:scikit-learn/scikit-learn into de…
NicolasHug Dec 5, 2019
6671682
change default of base.score for multioutput
NicolasHug Dec 5, 2019
38b97cb
Merge branch 'multioutput_dep' into dep023
NicolasHug Dec 5, 2019
b5fe811
removed lots of useless decorators
NicolasHug Dec 5, 2019
5304343
changed default of copy in quantile_transform
NicolasHug Dec 5, 2019
226db87
removed six.py
NicolasHug Dec 5, 2019
53f9ecc
nmf default value of init param
NicolasHug Dec 5, 2019
d80940a
raise error instead of warning in LinearDiscriminantAnalysis
NicolasHug Dec 5, 2019
16b3c9c
removed label param in hamming_loss
NicolasHug Dec 5, 2019
7af6207
updated method parameter of power_transform
NicolasHug Dec 5, 2019
808ab05
pep8
NicolasHug Dec 5, 2019
0d574a0
changed default value of min_impurity_split
NicolasHug Dec 5, 2019
5a4c2d5
removed assert_false and assert_true
NicolasHug Dec 5, 2019
887edd7
Merge branch 'master' of github.com:scikit-learn/scikit-learn into de…
NicolasHug Dec 9, 2019
04ec379
added and fixed versionchanged directives
NicolasHug Dec 9, 2019
015ad40
reset min_impurity_split default to None
NicolasHug Dec 9, 2019
e6443a5
fixed LDA issue
NicolasHug Dec 9, 2019
09bf4e5
fixed some tests
NicolasHug Dec 9, 2019
1fae94f
more docstrings updates
NicolasHug Dec 9, 2019
43fea84
set min_impurity_decrease for test to pass
NicolasHug Dec 9, 2019
7cd20a0
update docstring example
NicolasHug Dec 9, 2019
7fb0872
fixed doctest
NicolasHug Dec 9, 2019
dec2847
Merge branch 'master' of github.com:scikit-learn/scikit-learn into de…
NicolasHug Dec 11, 2019
bfa24b0
removed multioutput.score since it's now consistent with the default
NicolasHug Dec 11, 2019
930d64a
deprecate least_angle parameter combination
NicolasHug Dec 11, 2019
95c5ac1
remove support for l1 or l2 loss in svm
NicolasHug Dec 11, 2019
a9efc85
removed linear_assignment.py
NicolasHug Dec 11, 2019
3c1f5fb
Merge branch 'master' of github.com:scikit-learn/scikit-learn into de…
NicolasHug Dec 13, 2019
e917a5a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into de…
NicolasHug Jan 8, 2020
2015a16
add test
NicolasHug Jan 8, 2020
18 changes: 4 additions & 14 deletions sklearn/linear_model/_least_angle.py
@@ -47,12 +47,6 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, alpha_min=0,
Input data. Note that if X is None then the Gram matrix must be
specified, i.e., cannot be None or False.

.. deprecated:: 0.21

The use of ``X`` is ``None`` in combination with ``Gram`` is not
``None`` will be removed in v0.23. Use :func:`lars_path_gram`
instead.

y : None or array-like of shape (n_samples,)
Input targets.

@@ -67,11 +61,6 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, alpha_min=0,
matrix is precomputed from the given X, if there are more samples
than features.

.. deprecated:: 0.21

The use of ``X`` is ``None`` in combination with ``Gram`` is not
None will be removed in v0.23. Use :func:`lars_path_gram` instead.

max_iter : int, default=500
Maximum number of iterations to perform, set to infinity for no limit.

@@ -155,9 +144,10 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500, alpha_min=0,

"""
if X is None and Gram is not None:
warnings.warn('Use lars_path_gram to avoid passing X and y. '
'The current option will be removed in v0.23.',
FutureWarning)
raise ValueError(
    'X cannot be None if Gram is not None. '
    'Use lars_path_gram to avoid passing X and y.'
)
return _lars_path_solver(
X=X, y=y, Xy=Xy, Gram=Gram, n_samples=None, max_iter=max_iter,
alpha_min=alpha_min, method=method, copy_X=copy_X,
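
With this change, calling lars_path with X=None and a precomputed Gram matrix raises a ValueError instead of emitting a FutureWarning. A minimal sketch of the replacement path, assuming the caller precomputes Gram = X.T @ X and Xy = X.T @ y:

import numpy as np
from sklearn.linear_model import lars_path, lars_path_gram

rng = np.random.RandomState(0)
X = rng.randn(50, 5)
y = rng.randn(50)

# The previously deprecated call now fails outright.
try:
    lars_path(X=None, y=y, Gram=X.T @ X)
except ValueError as exc:
    print(exc)

# lars_path_gram runs LARS from the precomputed statistics directly.
alphas, active, coefs = lars_path_gram(Xy=X.T @ y, Gram=X.T @ X,
                                       n_samples=X.shape[0])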
9 changes: 8 additions & 1 deletion sklearn/linear_model/tests/test_least_angle.py
@@ -15,7 +15,8 @@
from sklearn.utils._testing import TempMemmap
from sklearn.exceptions import ConvergenceWarning
from sklearn import linear_model, datasets
from sklearn.linear_model._least_angle import _lars_path_residues, LassoLarsIC
from sklearn.linear_model._least_angle import _lars_path_residues
from sklearn.linear_model import LassoLarsIC, lars_path

# TODO: use another dataset that has multiple drops
diabetes = datasets.load_diabetes()
@@ -730,3 +731,9 @@ def test_lasso_lars_fit_copyX_behaviour(copy_X):
y = X[:, 2]
lasso_lars.fit(X, y, copy_X=copy_X)
assert copy_X == np.array_equal(X, X_copy)


def test_X_none_gram_not_none():
with pytest.raises(ValueError,
match="X cannot be None if Gram is not None"):
lars_path(X=None, y=[1], Gram='not None')
38 changes: 0 additions & 38 deletions sklearn/multioutput.py
@@ -265,44 +265,6 @@ def partial_fit(self, X, y, sample_weight=None):
super().partial_fit(
X, y, sample_weight=sample_weight)

# XXX Remove this method in 0.23
def score(self, X, y, sample_weight=None):
"""Returns the coefficient of determination R^2 of the prediction.

The coefficient R^2 is defined as (1 - u/v), where u is the residual
sum of squares ((y_true - y_pred) ** 2).sum() and v is the regression
sum of squares ((y_true - y_true.mean()) ** 2).sum().
Best possible score is 1.0 and it can be negative (because the
model can be arbitrarily worse). A constant model that always
predicts the expected value of y, disregarding the input features,
would get a R^2 score of 0.0.

Notes
-----
R^2 is calculated by weighting all the targets equally using
`multioutput='uniform_average'`.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Test samples.

y : array-like, shape (n_samples) or (n_samples, n_outputs)
True values for X.

sample_weight : array-like, shape [n_samples], optional
Sample weights.

Returns
-------
score : float
R^2 of self.predict(X) wrt. y.
"""
# XXX remove in 0.19 when r2_score default for multioutput changes
from .metrics import r2_score
return r2_score(y, self.predict(X), sample_weight=sample_weight,
multioutput='uniform_average')


class MultiOutputClassifier(ClassifierMixin, _MultiOutputEstimator):
"""Multi target classification
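
The override can be dropped because RegressorMixin.score and r2_score now share the same multioutput default ('uniform_average'), so the inherited score already behaves as this method did. A small sketch of the equivalence, using Ridge as an arbitrary base estimator purely for illustration:

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.multioutput import MultiOutputRegressor

X, y = make_regression(n_samples=100, n_features=10, n_targets=3,
                       random_state=0)
est = MultiOutputRegressor(Ridge()).fit(X, y)

# The inherited score matches r2_score with uniform target weighting.
assert np.isclose(est.score(X, y),
                  r2_score(y, est.predict(X),
                           multioutput='uniform_average'))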
26 changes: 0 additions & 26 deletions sklearn/svm/_classes.py
@@ -214,18 +214,6 @@ def fit(self, X, y, sample_weight=None):
self : object
An instance of the estimator.
"""
# FIXME Remove l1/l2 support in 0.23 ----------------------------------
msg = ("loss='%s' has been deprecated in favor of "
"loss='%s' as of 0.16. Backward compatibility"
" for the loss='%s' will be removed in %s")

if self.loss in ('l1', 'l2'):
old_loss = self.loss
self.loss = {'l1': 'hinge', 'l2': 'squared_hinge'}.get(self.loss)
warnings.warn(msg % (old_loss, self.loss, old_loss, '0.23'),
FutureWarning)
# ---------------------------------------------------------------------

if self.C < 0:
raise ValueError("Penalty term must be positive; got (C=%r)"
% self.C)
@@ -403,20 +391,6 @@ def fit(self, X, y, sample_weight=None):
-------
self : object
"""
# FIXME Remove l1/l2 support in 0.23 ----------------------------------
msg = ("loss='%s' has been deprecated in favor of "
"loss='%s' as of 0.16. Backward compatibility"
" for the loss='%s' will be removed in %s")

if self.loss in ('l1', 'l2'):
old_loss = self.loss
self.loss = {'l1': 'epsilon_insensitive',
'l2': 'squared_epsilon_insensitive'
}.get(self.loss)
warnings.warn(msg % (old_loss, self.loss, old_loss, '0.23'),
FutureWarning)
# ---------------------------------------------------------------------

if self.C < 0:
raise ValueError("Penalty term must be positive; got (C=%r)"
% self.C)
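
With the alias translation gone, only the canonical loss names are accepted; the removed 'l1'/'l2' spellings now fail liblinear's parameter validation instead of warning and substituting. A short sketch on a toy dataset:

from sklearn.svm import LinearSVC, LinearSVR

X, y = [[0.0], [1.0]], [0, 1]

# Canonical names (formerly aliased as 'l1' and 'l2' respectively).
LinearSVC(loss='hinge').fit(X, y)
LinearSVC(loss='squared_hinge').fit(X, y)
LinearSVR(loss='epsilon_insensitive').fit(X, [0.0, 1.0])
LinearSVR(loss='squared_epsilon_insensitive').fit(X, [0.0, 1.0])

# The removed aliases now raise rather than being rewritten.
try:
    LinearSVC(loss='l2').fit(X, y)
except ValueError as exc:
    print(exc)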
33 changes: 0 additions & 33 deletions sklearn/svm/tests/test_svm.py
@@ -736,39 +736,6 @@ def test_linearsvc_parameters():
svm.LinearSVC(loss="l3").fit(X, y)


# FIXME remove in 0.23
def test_linearsvx_loss_penalty_deprecations():
X, y = [[0.0], [1.0]], [0, 1]

msg = ("loss='%s' has been deprecated in favor of "
"loss='%s' as of 0.16. Backward compatibility"
" for the %s will be removed in %s")

# LinearSVC
# loss l1 --> hinge
assert_warns_message(FutureWarning,
msg % ("l1", "hinge", "loss='l1'", "0.23"),
svm.LinearSVC(loss="l1").fit, X, y)

# loss l2 --> squared_hinge
assert_warns_message(FutureWarning,
msg % ("l2", "squared_hinge", "loss='l2'", "0.23"),
svm.LinearSVC(loss="l2").fit, X, y)

# LinearSVR
# loss l1 --> epsilon_insensitive
assert_warns_message(FutureWarning,
msg % ("l1", "epsilon_insensitive", "loss='l1'",
"0.23"),
svm.LinearSVR(loss="l1").fit, X, y)

# loss l2 --> squared_epsilon_insensitive
assert_warns_message(FutureWarning,
msg % ("l2", "squared_epsilon_insensitive",
"loss='l2'", "0.23"),
svm.LinearSVR(loss="l2").fit, X, y)


def test_linear_svx_uppercase_loss_penality_raises_error():
# Check if Upper case notation raises error at _fit_liblinear
# which is called by fit