Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Add sample_weight to the calculation of alphas in enet_path and LinearModelCV #23045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions sklearn/linear_model/_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def _alpha_grid(
n_alphas=100,
normalize=False,
copy_X=True,
sample_weight=None,
):
"""Compute the grid of alpha values for elastic net parameter search

Expand Down Expand Up @@ -138,6 +139,8 @@ def _alpha_grid(

copy_X : bool, default=True
If ``True``, X will be copied; else, it may be overwritten.

sample_weight : ndarray of shape (n_samples,)
"""
if l1_ratio == 0:
raise ValueError(
Expand All @@ -146,47 +149,38 @@ def _alpha_grid(
"your estimator with the appropriate `alphas=` "
"argument."
)
n_samples = len(y)

sparse_center = False
if Xy is None:
X_sparse = sparse.isspmatrix(X)
sparse_center = X_sparse and (fit_intercept or normalize)
X = check_array(
X, accept_sparse="csc", copy=(copy_X and fit_intercept and not X_sparse)
if Xy is not None:
Xyw = Xy
else:
X, y, X_offset, _, _ = _preprocess_data(
X,
y,
fit_intercept,
normalize=normalize,
copy=copy_X,
sample_weight=sample_weight,
check_input=False,
)
if not X_sparse:
# X can be touched inplace thanks to the above line
X, y, _, _, _ = _preprocess_data(X, y, fit_intercept, normalize, copy=False)
Xy = safe_sparse_dot(X.T, y, dense_output=True)

if sparse_center:
# Workaround to find alpha_max for sparse matrices.
# since we should not destroy the sparsity of such matrices.
_, _, X_offset, _, X_scale = _preprocess_data(
X, y, fit_intercept, normalize
)
mean_dot = X_offset * np.sum(y)

if Xy.ndim == 1:
Xy = Xy[:, np.newaxis]

if sparse_center:
if fit_intercept:
Xy -= mean_dot[:, np.newaxis]
if normalize:
Xy /= X_scale[:, np.newaxis]
if sample_weight is not None:
yw = y * sample_weight
else:
yw = y
if sparse.issparse(X):
Xyw = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset
else:
Xyw = np.dot(X.T, yw)

alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (n_samples * l1_ratio)
if Xyw.ndim == 1:
Xyw = Xyw[:, np.newaxis]
if sample_weight is not None:
n_samples = sample_weight.sum()
else:
n_samples = X.shape[0]
alpha_max = np.max(np.sqrt(np.sum(Xyw**2, axis=1))) / (l1_ratio * n_samples)

if alpha_max <= np.finfo(float).resolution:
alphas = np.empty(n_alphas)
alphas.fill(np.finfo(float).resolution)
return alphas

return np.logspace(np.log10(alpha_max * eps), np.log10(alpha_max), num=n_alphas)[
::-1
]
return np.full(n_alphas, np.finfo(float).resolution)
return np.logspace(np.log10(alpha_max), np.log10(alpha_max * eps), num=n_alphas)


def lasso_path(
Expand Down Expand Up @@ -1660,6 +1654,7 @@ def fit(self, X, y, sample_weight=None):
n_alphas=self.n_alphas,
normalize=_normalize,
copy_X=self.copy_X,
sample_weight=sample_weight,
)
for l1_ratio in l1_ratios
]
Expand Down
21 changes: 21 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_less
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import _convert_container

Expand Down Expand Up @@ -1642,6 +1643,26 @@ def test_enet_cv_grid_search(sample_weight):
assert reg.alpha_ == pytest.approx(gs.best_params_["alpha"])


@pytest.mark.parametrize("sparseX", [False, True])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("sample_weight", [np.array([10, 1, 10, 1]), None])
def test_enet_alpha_max_sample_weight(sparseX, fit_intercept, sample_weight):
    """Check that alpha_max on the CV path accounts for sample_weight.

    At alpha_max every coefficient must be exactly zero, while any alpha
    strictly below alpha_max must produce at least one nonzero coefficient.
    """
    # Small dense design with an exact linear target y = X @ [1, 1].
    X = np.array([[3.0, 1.0], [2.0, 5.0], [5.0, 3.0], [1.0, 4.0]])
    y = X @ np.array([1, 1])
    if sparseX:
        X = sparse.csc_matrix(X)
    # With n_alphas=1 and eps=1 the alpha grid collapses to alpha_max alone,
    # so the fitted coefficients must all vanish.
    cv_model = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept)
    cv_model.fit(X, y, sample_weight=sample_weight)
    assert_almost_equal(cv_model.coef_, 0)
    # Refitting just below alpha_max must activate at least one coefficient.
    model = ElasticNet(alpha=0.99 * cv_model.alpha_, fit_intercept=fit_intercept)
    model.fit(X, y, sample_weight=sample_weight)
    assert_array_less(1e-3, np.max(np.abs(model.coef_)))


@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("l1_ratio", [0, 0.5, 1])
@pytest.mark.parametrize("precompute", [False, True])
Expand Down