From 8d4b501072a966878f6c144273633b7559d10d42 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Sun, 3 Apr 2022 22:11:49 -0400 Subject: [PATCH 1/9] Update _alpha_grid to take sample_weight It seems like this single call to _preprocess_data suffices in all cases. --- sklearn/linear_model/_coordinate_descent.py | 64 +++++++++------------ 1 file changed, 28 insertions(+), 36 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index b450790da5a07..bc338ebaa41db 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -92,6 +92,7 @@ def _alpha_grid( n_alphas=100, normalize=False, copy_X=True, + sample_weight=None, ): """Compute the grid of alpha values for elastic net parameter search @@ -138,6 +139,8 @@ def _alpha_grid( copy_X : bool, default=True If ``True``, X will be copied; else, it may be overwritten. + + sample_weight : ndarray of shape (n_samples,) """ if l1_ratio == 0: raise ValueError( @@ -146,47 +149,35 @@ def _alpha_grid( "your estimator with the appropriate `alphas=` " "argument." ) - n_samples = len(y) - - sparse_center = False - if Xy is None: - X_sparse = sparse.isspmatrix(X) - sparse_center = X_sparse and (fit_intercept or normalize) - X = check_array( - X, accept_sparse="csc", copy=(copy_X and fit_intercept and not X_sparse) + if Xy is not None and ((sample_weight is None) or (fit_intercept is False)): + # In this case, the precomputed Xy should be valid. + pass + else: + # Compute Xy. + X, y, X_offset, y_offset, X_scale = _preprocess_data( + X, + y, + fit_intercept, + normalize=normalize, + copy=copy_X, + sample_weight=sample_weight, + check_input=False, ) - if not X_sparse: - # X can be touched inplace thanks to the above line - X, y, _, _, _ = _preprocess_data(X, y, fit_intercept, normalize, copy=False) - Xy = safe_sparse_dot(X.T, y, dense_output=True) - - if sparse_center: - # Workaround to find alpha_max for sparse matrices. - # since we should not destroy the sparsity of such matrices. - _, _, X_offset, _, X_scale = _preprocess_data( - X, y, fit_intercept, normalize - ) - mean_dot = X_offset * np.sum(y) + if sample_weight is not None: + yw = y * sample_weight / sample_weight.mean() + else: + yw = y + if sparse.issparse(X): + Xy = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset + else: + Xy = np.dot(X.T, yw) if Xy.ndim == 1: Xy = Xy[:, np.newaxis] - - if sparse_center: - if fit_intercept: - Xy -= mean_dot[:, np.newaxis] - if normalize: - Xy /= X_scale[:, np.newaxis] - - alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (n_samples * l1_ratio) - + alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (l1_ratio * X.shape[0]) if alpha_max <= np.finfo(float).resolution: - alphas = np.empty(n_alphas) - alphas.fill(np.finfo(float).resolution) - return alphas - - return np.logspace(np.log10(alpha_max * eps), np.log10(alpha_max), num=n_alphas)[ - ::-1 - ] + return np.full(n_alphas, np.finfo(float).resolution) + return np.logspace(np.log10(alpha_max), np.log10(alpha_max * eps), num=n_alphas) def lasso_path( @@ -1660,6 +1651,7 @@ def fit(self, X, y, sample_weight=None): n_alphas=self.n_alphas, normalize=_normalize, copy_X=self.copy_X, + sample_weight=sample_weight, ) for l1_ratio in l1_ratios ] From 2f494db40f647ca99f57807a3934943caecc7d58 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Sun, 3 Apr 2022 22:12:51 -0400 Subject: [PATCH 2/9] Add a simple test for alpha_max with sample_weight This tiny example was given in https://github.com/scikit-learn/scikit-learn/issues/22914. The test merely asserts that alpha_max is large enough to force the coefficient to 0. --- .../linear_model/tests/test_coordinate_descent.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index e5d7ba358c1f5..f38fd3933371b 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -1642,6 +1642,18 @@ def test_enet_cv_grid_search(sample_weight): assert reg.alpha_ == pytest.approx(gs.best_params_["alpha"]) +@pytest.mark.parametrize("sparseX", [False, True]) +def test_enet_alpha_max_sample_weight(sparseX): + X = np.array([[3, 1], [2, 5], [5, 3], [1, 4]]) + beta = np.array([1, 1]) + y = X @ beta + sample_weight = np.array([10, 1, 10, 1]) + if sparseX: + X = sparse.csc_matrix(X) + reg = ElasticNetCV(n_alphas=1, cv=2).fit(X, y, sample_weight=sample_weight) + assert_almost_equal(reg.coef_, 0) + + @pytest.mark.parametrize("fit_intercept", [True, False]) @pytest.mark.parametrize("l1_ratio", [0, 0.5, 1]) @pytest.mark.parametrize("precompute", [False, True]) From fa2c8215bf61e1f553dcfed3dae539a3769a3b00 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Mon, 4 Apr 2022 09:06:16 -0400 Subject: [PATCH 3/9] Update test As per reviewer's suggestions: (1) Clarify eps=1. (2) Parameterize `fit_intercept`. --- sklearn/linear_model/tests/test_coordinate_descent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index f38fd3933371b..807429985dcac 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -1643,14 +1643,16 @@ def test_enet_cv_grid_search(sample_weight): @pytest.mark.parametrize("sparseX", [False, True]) -def test_enet_alpha_max_sample_weight(sparseX): +@pytest.mark.parametrize("fit_intercept", [False, True]) +def test_enet_alpha_max_sample_weight(sparseX, fit_intercept): X = np.array([[3, 1], [2, 5], [5, 3], [1, 4]]) beta = np.array([1, 1]) y = X @ beta sample_weight = np.array([10, 1, 10, 1]) if sparseX: X = sparse.csc_matrix(X) - reg = ElasticNetCV(n_alphas=1, cv=2).fit(X, y, sample_weight=sample_weight) + reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept) + reg.fit(X, y, sample_weight=sample_weight) assert_almost_equal(reg.coef_, 0) From 75e65842e9312af3296a96ccab7d38dfbb0b59e6 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Mon, 4 Apr 2022 09:10:45 -0400 Subject: [PATCH 4/9] Clarify _alpha_grid. (1) Give the name `n_samples` to the quantity `X.shape[0]`. (2) Clarify that `y_offset` and `X_scale` are not used, since these are already applied to the data by `_preprocess_data`. --- sklearn/linear_model/_coordinate_descent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index bc338ebaa41db..6c710af8b5137 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -154,7 +154,7 @@ def _alpha_grid( pass else: # Compute Xy. - X, y, X_offset, y_offset, X_scale = _preprocess_data( + X, y, X_offset, _, _ = _preprocess_data( X, y, fit_intercept, @@ -174,7 +174,9 @@ def _alpha_grid( if Xy.ndim == 1: Xy = Xy[:, np.newaxis] - alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (l1_ratio * X.shape[0]) + n_samples = X.shape[0] + alpha_max = np.max(np.sqrt(np.sum(Xy**2, axis=1))) / (l1_ratio * n_samples) + if alpha_max <= np.finfo(float).resolution: return np.full(n_alphas, np.finfo(float).resolution) return np.logspace(np.log10(alpha_max), np.log10(alpha_max * eps), num=n_alphas) From 8b6cfc0bfb53faf56cd6e410e35b4b83ebc3f219 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Tue, 5 Apr 2022 18:16:49 -0400 Subject: [PATCH 5/9] Clarify notation --- sklearn/linear_model/_coordinate_descent.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 6c710af8b5137..3655bd667c2d0 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -150,10 +150,8 @@ def _alpha_grid( "argument." ) if Xy is not None and ((sample_weight is None) or (fit_intercept is False)): - # In this case, the precomputed Xy should be valid. - pass + Xyw = Xy else: - # Compute Xy. X, y, X_offset, _, _ = _preprocess_data( X, y, @@ -168,14 +166,14 @@ def _alpha_grid( else: yw = y if sparse.issparse(X): - Xy = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset + Xyw = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset else: - Xy = np.dot(X.T, yw) + Xyw = np.dot(X.T, yw) - if Xy.ndim == 1: - Xy = Xy[:, np.newaxis] + if Xyw.ndim == 1: + Xyw = Xyw[:, np.newaxis] n_samples = X.shape[0] - alpha_max = np.max(np.sqrt(np.sum(Xy**2, axis=1))) / (l1_ratio * n_samples) + alpha_max = np.max(np.sqrt(np.sum(Xyw**2, axis=1))) / (l1_ratio * n_samples) if alpha_max <= np.finfo(float).resolution: return np.full(n_alphas, np.finfo(float).resolution) From 2ba4c57acd36b7a6c92b7322ce0cb9fc18cd34ce Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Sat, 2 Jul 2022 09:51:25 -0400 Subject: [PATCH 6/9] Use Xy if it is provided. --- sklearn/linear_model/_coordinate_descent.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 3655bd667c2d0..b1597bedce814 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -149,7 +149,7 @@ def _alpha_grid( "your estimator with the appropriate `alphas=` " "argument." ) - if Xy is not None and ((sample_weight is None) or (fit_intercept is False)): + if Xy is not None: Xyw = Xy else: X, y, X_offset, _, _ = _preprocess_data( @@ -162,7 +162,7 @@ def _alpha_grid( check_input=False, ) if sample_weight is not None: - yw = y * sample_weight / sample_weight.mean() + yw = y * sample_weight else: yw = y if sparse.issparse(X): @@ -172,7 +172,10 @@ def _alpha_grid( if Xyw.ndim == 1: Xyw = Xyw[:, np.newaxis] - n_samples = X.shape[0] + if sample_weight is not None: + n_samples = sample_weight.sum() + else: + n_samples = X.shape[0] alpha_max = np.max(np.sqrt(np.sum(Xyw**2, axis=1))) / (l1_ratio * n_samples) if alpha_max <= np.finfo(float).resolution: From 5d1f5e7b09aa08cdf5b7d980ef114bcafb9f8c52 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Sat, 2 Jul 2022 10:08:08 -0400 Subject: [PATCH 7/9] Update test, check alpha_max is not too large --- sklearn/linear_model/tests/test_coordinate_descent.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 807429985dcac..8fdc9dbc43c96 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -26,6 +26,7 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal +from sklearn.utils._testing import assert_array_less from sklearn.utils._testing import ignore_warnings from sklearn.utils._testing import _convert_container @@ -1651,9 +1652,14 @@ def test_enet_alpha_max_sample_weight(sparseX, fit_intercept): sample_weight = np.array([10, 1, 10, 1]) if sparseX: X = sparse.csc_matrix(X) + # Test alpha_max makes coefs zero. reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept) reg.fit(X, y, sample_weight=sample_weight) assert_almost_equal(reg.coef_, 0) + # Test smaller alpha makes coefs nonzero. + reg = ElasticNetCV(n_alphas=2, cv=2, eps=0.99, fit_intercept=fit_intercept) + reg.fit(X, y, sample_weight=sample_weight) + assert_array_less(0, np.max(np.abs(reg.coef_))) @pytest.mark.parametrize("fit_intercept", [True, False]) From dce169c41c4613e50fbf61ac18883722edb47ffe Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Sat, 2 Jul 2022 12:43:15 -0400 Subject: [PATCH 8/9] Fix test that alpha_max is not too large. --- sklearn/linear_model/tests/test_coordinate_descent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 8fdc9dbc43c96..a414a14c71c06 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -1656,10 +1656,11 @@ def test_enet_alpha_max_sample_weight(sparseX, fit_intercept): reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept) reg.fit(X, y, sample_weight=sample_weight) assert_almost_equal(reg.coef_, 0) + alpha_max = reg.alpha_ # Test smaller alpha makes coefs nonzero. - reg = ElasticNetCV(n_alphas=2, cv=2, eps=0.99, fit_intercept=fit_intercept) + reg = ElasticNet(alpha=0.99 * alpha_max, fit_intercept=fit_intercept) reg.fit(X, y, sample_weight=sample_weight) - assert_array_less(0, np.max(np.abs(reg.coef_))) + assert_array_less(1e-3, np.max(np.abs(reg.coef_))) @pytest.mark.parametrize("fit_intercept", [True, False]) From 380c21f6f342b69021b0f5269e919c2cb82980c9 Mon Sep 17 00:00:00 2001 From: "Mr. Snrub" <45150804+s-banach@users.noreply.github.com> Date: Tue, 5 Jul 2022 22:39:17 -0400 Subject: [PATCH 9/9] Test alpha_max without sample_weight. --- sklearn/linear_model/tests/test_coordinate_descent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index a414a14c71c06..c7c4a839b7170 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -1645,11 +1645,11 @@ def test_enet_cv_grid_search(sample_weight): @pytest.mark.parametrize("sparseX", [False, True]) @pytest.mark.parametrize("fit_intercept", [False, True]) -def test_enet_alpha_max_sample_weight(sparseX, fit_intercept): - X = np.array([[3, 1], [2, 5], [5, 3], [1, 4]]) +@pytest.mark.parametrize("sample_weight", [np.array([10, 1, 10, 1]), None]) +def test_enet_alpha_max_sample_weight(sparseX, fit_intercept, sample_weight): + X = np.array([[3.0, 1.0], [2.0, 5.0], [5.0, 3.0], [1.0, 4.0]]) beta = np.array([1, 1]) y = X @ beta - sample_weight = np.array([10, 1, 10, 1]) if sparseX: X = sparse.csc_matrix(X) # Test alpha_max makes coefs zero.