From 12c8c07ac6e04045d7edf1917f62001c7e70a295 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Aug 2025 15:29:49 +0200 Subject: [PATCH 1/8] MNT renaming of variables --- sklearn/linear_model/_cd_fast.pyx | 160 +++++++++++++++--------------- 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 369ab162d563c..92a4122dad484 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -450,7 +450,7 @@ def sparse_enet_coordinate_descent( # We work with: # yw = sample_weight * y # R = sample_weight * residual - # norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0) + # norm2_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0) if floating is float: dtype = np.float32 @@ -462,7 +462,7 @@ def sparse_enet_coordinate_descent( cdef unsigned int n_features = w.shape[0] # compute norms of the columns of X - cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype) + cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype) # initial value of the residuals # R = y - Zw, weighted version R = sample_weight * (y - Zw) @@ -471,14 +471,14 @@ def sparse_enet_coordinate_descent( cdef const floating[::1] yw cdef floating tmp - cdef floating w_ii + cdef floating w_j cdef floating d_w_max cdef floating w_max - cdef floating d_w_ii + cdef floating d_w_j cdef floating gap = tol + 1.0 cdef floating d_w_tol = tol cdef floating dual_norm_XtA - cdef floating X_mean_ii + cdef floating X_mean_j cdef floating R_sum = 0.0 cdef floating R_norm2 cdef floating w_norm2 @@ -486,8 +486,8 @@ def sparse_enet_coordinate_descent( cdef floating const_ cdef floating A_norm2 cdef floating normalize_sum - cdef unsigned int ii - cdef unsigned int jj + cdef unsigned int i + cdef unsigned int j cdef unsigned int n_iter = 0 cdef unsigned int f_iter cdef unsigned int startptr = X_indptr[0] @@ -507,46 +507,46 @@ def sparse_enet_coordinate_descent( with nogil: # center = (X_mean != 0).any() - for ii in range(n_features): - if X_mean[ii]: + for j in range(n_features): + if X_mean[j]: center = True break # R = y - np.dot(X, w) - for ii in range(n_features): - X_mean_ii = X_mean[ii] - endptr = X_indptr[ii + 1] + for j in range(n_features): + X_mean_j = X_mean[j] + endptr = X_indptr[j + 1] normalize_sum = 0.0 - w_ii = w[ii] + w_j = w[j] if no_sample_weights: - for jj in range(startptr, endptr): - normalize_sum += (X_data[jj] - X_mean_ii) ** 2 - R[X_indices[jj]] -= X_data[jj] * w_ii - norm_cols_X[ii] = normalize_sum + \ - (n_samples - endptr + startptr) * X_mean_ii ** 2 + for i in range(startptr, endptr): + normalize_sum += (X_data[i] - X_mean_j) ** 2 + R[X_indices[i]] -= X_data[i] * w_j + norm2_cols_X[j] = normalize_sum + \ + (n_samples - endptr + startptr) * X_mean_j ** 2 if center: - for jj in range(n_samples): - R[jj] += X_mean_ii * w_ii - R_sum += R[jj] + for i in range(n_samples): + R[i] += X_mean_j * w_j + R_sum += R[i] else: # R = sw * (y - np.dot(X, w)) - for jj in range(startptr, endptr): - tmp = sample_weight[X_indices[jj]] + for i in range(startptr, endptr): + tmp = sample_weight[X_indices[i]] # second term will be subtracted by loop over range(n_samples) - normalize_sum += (tmp * (X_data[jj] - X_mean_ii) ** 2 - - tmp * X_mean_ii ** 2) - R[X_indices[jj]] -= tmp * X_data[jj] * w_ii + normalize_sum += (tmp * (X_data[i] - X_mean_j) ** 2 + - tmp * X_mean_j ** 2) + R[X_indices[i]] -= tmp * X_data[i] * w_j if center: - for jj in range(n_samples): - normalize_sum += sample_weight[jj] 
* X_mean_ii ** 2 - R[jj] += sample_weight[jj] * X_mean_ii * w_ii - R_sum += R[jj] - norm_cols_X[ii] = normalize_sum + for i in range(n_samples): + normalize_sum += sample_weight[i] * X_mean_j ** 2 + R[i] += sample_weight[i] * X_mean_j * w_j + R_sum += R[i] + norm2_cols_X[j] = normalize_sum startptr = endptr # Note: No need to update R_sum from here on because the update terms cancel - # each other: w_ii * np.sum(X[:,ii] - X_mean[ii]) = 0. R_sum is only ever + # each other: w_j * np.sum(X[:,j] - X_mean[j]) = 0. R_sum is only ever # needed and calculated if X_mean is provided. # tol *= np.dot(y, y) @@ -560,69 +560,69 @@ def sparse_enet_coordinate_descent( for f_iter in range(n_features): # Loop over coordinates if random: - ii = rand_int(n_features, rand_r_state) + j = rand_int(n_features, rand_r_state) else: - ii = f_iter + j = f_iter - if norm_cols_X[ii] == 0.0: + if norm2_cols_X[j] == 0.0: continue - startptr = X_indptr[ii] - endptr = X_indptr[ii + 1] - w_ii = w[ii] # Store previous value - X_mean_ii = X_mean[ii] + startptr = X_indptr[j] + endptr = X_indptr[j + 1] + w_j = w[j] # Store previous value + X_mean_j = X_mean[j] - if w_ii != 0.0: - # R += w_ii * X[:,ii] + if w_j != 0.0: + # R += w_j * X[:,j] if no_sample_weights: - for jj in range(startptr, endptr): - R[X_indices[jj]] += X_data[jj] * w_ii + for i in range(startptr, endptr): + R[X_indices[i]] += X_data[i] * w_j if center: - for jj in range(n_samples): - R[jj] -= X_mean_ii * w_ii + for i in range(n_samples): + R[i] -= X_mean_j * w_j else: - for jj in range(startptr, endptr): - tmp = sample_weight[X_indices[jj]] - R[X_indices[jj]] += tmp * X_data[jj] * w_ii + for i in range(startptr, endptr): + tmp = sample_weight[X_indices[i]] + R[X_indices[i]] += tmp * X_data[i] * w_j if center: - for jj in range(n_samples): - R[jj] -= sample_weight[jj] * X_mean_ii * w_ii + for i in range(n_samples): + R[i] -= sample_weight[i] * X_mean_j * w_j - # tmp = (X[:,ii] * R).sum() + # tmp = (X[:,j] * R).sum() tmp = 0.0 - for jj in range(startptr, endptr): - tmp += R[X_indices[jj]] * X_data[jj] + for i in range(startptr, endptr): + tmp += R[X_indices[i]] * X_data[i] if center: - tmp -= R_sum * X_mean_ii + tmp -= R_sum * X_mean_j if positive and tmp < 0.0: - w[ii] = 0.0 + w[j] = 0.0 else: - w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ - / (norm_cols_X[ii] + beta) + w[j] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ + / (norm2_cols_X[j] + beta) - if w[ii] != 0.0: - # R -= w[ii] * X[:,ii] # Update residual + if w[j] != 0.0: + # R -= w[j] * X[:,j] # Update residual if no_sample_weights: - for jj in range(startptr, endptr): - R[X_indices[jj]] -= X_data[jj] * w[ii] + for i in range(startptr, endptr): + R[X_indices[i]] -= X_data[i] * w[j] if center: - for jj in range(n_samples): - R[jj] += X_mean_ii * w[ii] + for i in range(n_samples): + R[i] += X_mean_j * w[j] else: - for jj in range(startptr, endptr): - tmp = sample_weight[X_indices[jj]] - R[X_indices[jj]] -= tmp * X_data[jj] * w[ii] + for i in range(startptr, endptr): + tmp = sample_weight[X_indices[i]] + R[X_indices[i]] -= tmp * X_data[i] * w[j] if center: - for jj in range(n_samples): - R[jj] += sample_weight[jj] * X_mean_ii * w[ii] + for i in range(n_samples): + R[i] += sample_weight[i] * X_mean_j * w[j] # update the maximum absolute coefficient update - d_w_ii = fabs(w[ii] - w_ii) - d_w_max = fmax(d_w_max, d_w_ii) + d_w_j = fabs(w[j] - w_j) + d_w_max = fmax(d_w_max, d_w_j) - w_max = fmax(w_max, fabs(w[ii])) + w_max = fmax(w_max, fabs(w[j])) if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == 
max_iter - 1: # the biggest coordinate update of this iteration was smaller than @@ -631,14 +631,14 @@ def sparse_enet_coordinate_descent( # XtA = X.T @ R - beta * w # sparse X.T / dense R dot product - for ii in range(n_features): - XtA[ii] = 0.0 - for kk in range(X_indptr[ii], X_indptr[ii + 1]): - XtA[ii] += X_data[kk] * R[X_indices[kk]] + for j in range(n_features): + XtA[j] = 0.0 + for kk in range(X_indptr[j], X_indptr[j + 1]): + XtA[j] += X_data[kk] * R[X_indices[kk]] if center: - XtA[ii] -= X_mean[ii] * R_sum - XtA[ii] -= beta * w[ii] + XtA[j] -= X_mean[j] * R_sum + XtA[j] -= beta * w[j] if positive: dual_norm_XtA = max(n_features, &XtA[0]) @@ -650,10 +650,10 @@ def sparse_enet_coordinate_descent( R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1) else: R_norm2 = 0.0 - for jj in range(n_samples): + for i in range(n_samples): # R is already multiplied by sample_weight - if sample_weight[jj] != 0: - R_norm2 += (R[jj] ** 2) / sample_weight[jj] + if sample_weight[i] != 0: + R_norm2 += (R[i] ** 2) / sample_weight[i] # w_norm2 = np.dot(w, w) w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1) From 1708481e4bcbcd0142cd01b74904bcfa6168ae92 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Aug 2025 16:44:33 +0200 Subject: [PATCH 2/8] MNT add R_plus_wj_Xj --- sklearn/linear_model/_cd_fast.pyx | 85 +++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 26 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 92a4122dad484..cd850399f6e8b 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -401,6 +401,39 @@ def enet_coordinate_descent( return np.asarray(w), gap, tol, n_iter + 1 +cdef inline void R_plus_wj_Xj( + unsigned int n_samples, + floating[::1] R, # out + const floating[::1] X_data, + const int[::1] X_indices, + const int[::1] X_indptr, + const floating[::1] X_mean, + bint center, + const floating[::1] sample_weight, + bint no_sample_weights, + floating w_j, + unsigned int j, +) noexcept nogil: + """R += w_j * X[:,j]""" + cdef unsigned int startptr = X_indptr[j] + cdef unsigned int endptr = X_indptr[j + 1] + cdef floating sw + cdef floating X_mean_j = X_mean[j] + if no_sample_weights: + for i in range(startptr, endptr): + R[X_indices[i]] += X_data[i] * w_j + if center: + for i in range(n_samples): + R[i] -= X_mean_j * w_j + else: + for i in range(startptr, endptr): + sw = sample_weight[X_indices[i]] + R[X_indices[i]] += sw * X_data[i] * w_j + if center: + for i in range(n_samples): + R[i] -= sample_weight[i] * X_mean_j * w_j + + def sparse_enet_coordinate_descent( floating[::1] w, floating alpha, @@ -574,19 +607,19 @@ def sparse_enet_coordinate_descent( if w_j != 0.0: # R += w_j * X[:,j] - if no_sample_weights: - for i in range(startptr, endptr): - R[X_indices[i]] += X_data[i] * w_j - if center: - for i in range(n_samples): - R[i] -= X_mean_j * w_j - else: - for i in range(startptr, endptr): - tmp = sample_weight[X_indices[i]] - R[X_indices[i]] += tmp * X_data[i] * w_j - if center: - for i in range(n_samples): - R[i] -= sample_weight[i] * X_mean_j * w_j + R_plus_wj_Xj( + n_samples, + R, + X_data, + X_indices, + X_indptr, + X_mean, + center, + sample_weight, + no_sample_weights, + w_j, + j, + ) # tmp = (X[:,j] * R).sum() tmp = 0.0 @@ -604,19 +637,19 @@ def sparse_enet_coordinate_descent( if w[j] != 0.0: # R -= w[j] * X[:,j] # Update residual - if no_sample_weights: - for i in range(startptr, endptr): - R[X_indices[i]] -= X_data[i] * w[j] - if center: - for i in range(n_samples): - R[i] += 
X_mean_j * w[j] - else: - for i in range(startptr, endptr): - tmp = sample_weight[X_indices[i]] - R[X_indices[i]] -= tmp * X_data[i] * w[j] - if center: - for i in range(n_samples): - R[i] += sample_weight[i] * X_mean_j * w[j] + R_plus_wj_Xj( + n_samples, + R, + X_data, + X_indices, + X_indptr, + X_mean, + center, + sample_weight, + no_sample_weights, + -w[j], + j, + ) # update the maximum absolute coefficient update d_w_j = fabs(w[j] - w_j) From d4f47e732e7147804dc48b247ac98cbb2bc36327 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Aug 2025 17:09:24 +0200 Subject: [PATCH 3/8] MNT add gap safe rule to sparse_enet_coordinate_descent --- sklearn/linear_model/_cd_fast.pyx | 252 ++++++++++++++---- sklearn/linear_model/_coordinate_descent.py | 1 + .../tests/test_coordinate_descent.py | 52 ++-- 3 files changed, 228 insertions(+), 77 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index cd850399f6e8b..b3f5dacd4ed50 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -434,6 +434,83 @@ cdef inline void R_plus_wj_Xj( R[i] -= sample_weight[i] * X_mean_j * w_j +cdef (floating, floating) sparse_gap_enet( + int n_samples, + int n_features, + const floating[::1] w, + floating alpha, # L1 penalty + floating beta, # L2 penalty + const floating[::1] X_data, + const int[::1] X_indices, + const int[::1] X_indptr, + const floating[::1] y, + const floating[::1] sample_weight, + bint no_sample_weights, + const floating[::1] X_mean, + bint center, + const floating[::1] R, # current residuals = y - X @ w + floating R_sum, + floating[::1] XtA, # XtA = X.T @ R - beta * w is calculated inplace + bint positive, +) noexcept nogil: + """Compute dual gap for use in sparse_enet_coordinate_descent.""" + cdef floating gap = 0.0 + cdef floating dual_norm_XtA + cdef floating R_norm2 + cdef floating w_norm2 = 0.0 + cdef floating l1_norm + cdef floating A_norm2 + cdef floating const_ + cdef unsigned int i, j + + # XtA = X.T @ R - beta * w + # sparse X.T @ dense R + for j in range(n_features): + XtA[j] = 0.0 + for i in range(X_indptr[j], X_indptr[j + 1]): + XtA[j] += X_data[i] * R[X_indices[i]] + + if center: + XtA[j] -= X_mean[j] * R_sum + XtA[j] -= beta * w[j] + + if positive: + dual_norm_XtA = max(n_features, &XtA[0]) + else: + dual_norm_XtA = abs_max(n_features, &XtA[0]) + + # R_norm2 = R @ R + if no_sample_weights: + R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1) + else: + R_norm2 = 0.0 + for i in range(n_samples): + # R is already multiplied by sample_weight + if sample_weight[i] != 0: + R_norm2 += (R[i] ** 2) / sample_weight[i] + + # w_norm2 = w @ w + if beta > 0: + w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1) + + if (dual_norm_XtA > alpha): + const_ = alpha / dual_norm_XtA + A_norm2 = R_norm2 * const_**2 + gap = 0.5 * (R_norm2 + A_norm2) + else: + const_ = 1.0 + gap = R_norm2 + + l1_norm = _asum(n_features, &w[0], 1) + + gap += ( + alpha * l1_norm + - const_ * _dot(n_samples, &R[0], 1, &y[0], 1) # R @ y + + 0.5 * beta * (1 + const_ ** 2) * w_norm2 + ) + return gap, dual_norm_XtA + + def sparse_enet_coordinate_descent( floating[::1] w, floating alpha, @@ -449,6 +526,7 @@ def sparse_enet_coordinate_descent( object rng, bint random=0, bint positive=0, + bint do_screening=1, ): """Cython version of the coordinate descent algorithm for Elastic-Net @@ -464,6 +542,8 @@ def sparse_enet_coordinate_descent( and X_mean is the weighted average of X (per column). + The rest is the same as enet_coordinate_descent, but for sparse X. 
+
     Returns
     -------
     w : ndarray of shape (n_features,)
@@ -513,12 +593,11 @@ def sparse_enet_coordinate_descent(
     cdef floating dual_norm_XtA
     cdef floating X_mean_j
     cdef floating R_sum = 0.0
-    cdef floating R_norm2
-    cdef floating w_norm2
-    cdef floating l1_norm
-    cdef floating const_
-    cdef floating A_norm2
     cdef floating normalize_sum
+    cdef unsigned int n_active = n_features
+    cdef uint32_t[::1] active_set
+    # TODO: use bitset instead of array of bools
+    cdef uint8_t[::1] excluded_set
     cdef unsigned int i
     cdef unsigned int j
     cdef unsigned int n_iter = 0
@@ -529,7 +608,10 @@ def sparse_enet_coordinate_descent(
     cdef uint32_t* rand_r_state = &rand_r_state_seed
     cdef bint center = False
     cdef bint no_sample_weights = sample_weight is None
-    cdef int kk
+
+    if do_screening:
+        active_set = np.empty(n_features, dtype=np.uint32)  # map [:n_active] -> j
+        excluded_set = np.empty(n_features, dtype=np.uint8)
 
     if no_sample_weights:
         yw = y
@@ -586,17 +668,75 @@ def sparse_enet_coordinate_descent(
         # with sample weights: tol *= y @ (sw * y)
         tol *= _dot(n_samples, &y[0], 1, &yw[0], 1)
 
-        for n_iter in range(max_iter):
+        # Check convergence before entering the main loop.
+        gap, dual_norm_XtA = sparse_gap_enet(
+            n_samples,
+            n_features,
+            w,
+            alpha,
+            beta,
+            X_data,
+            X_indices,
+            X_indptr,
+            y,
+            sample_weight,
+            no_sample_weights,
+            X_mean,
+            center,
+            R,
+            R_sum,
+            XtA,
+            positive,
+        )
+        if gap <= tol:
+            with gil:
+                return np.asarray(w), gap, tol, 0
+        # Gap Safe Screening Rules, see https://arxiv.org/abs/1802.07481, Eq. 11
+        if do_screening:
+            n_active = 0
+            for j in range(n_features):
+                if norm2_cols_X[j] == 0:
+                    w[j] = 0
+                    excluded_set[j] = 1
+                    continue
+                Xj_theta = XtA[j] / fmax(alpha, dual_norm_XtA)  # X[:,j] @ dual_theta
+                d_j = (1 - fabs(Xj_theta)) / sqrt(norm2_cols_X[j] + beta)
+                if d_j <= sqrt(2 * gap) / alpha:
+                    # include feature j
+                    active_set[n_active] = j
+                    excluded_set[j] = 0
+                    n_active += 1
+                else:
+                    # R += w[j] * X[:,j]
+                    R_plus_wj_Xj(
+                        n_samples,
+                        R,
+                        X_data,
+                        X_indices,
+                        X_indptr,
+                        X_mean,
+                        center,
+                        sample_weight,
+                        no_sample_weights,
+                        w[j],
+                        j,
+                    )
+                    w[j] = 0
+                    excluded_set[j] = 1
+
+        for n_iter in range(max_iter):
             w_max = 0.0
             d_w_max = 0.0
-
-            for f_iter in range(n_features):  # Loop over coordinates
+            for f_iter in range(n_active):  # Loop over coordinates
                 if random:
-                    j = rand_int(n_features, rand_r_state)
+                    j = rand_int(n_active, rand_r_state)
                 else:
                     j = f_iter
 
+                if do_screening:
+                    j = active_set[j]
+
                 if norm2_cols_X[j] == 0.0:
                     continue
 
@@ -661,53 +801,61 @@ def sparse_enet_coordinate_descent(
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
-
-                # XtA = X.T @ R - beta * w
-                # sparse X.T / dense R dot product
-                for j in range(n_features):
-                    XtA[j] = 0.0
-                    for kk in range(X_indptr[j], X_indptr[j + 1]):
-                        XtA[j] += X_data[kk] * R[X_indices[kk]]
-
-                    if center:
-                        XtA[j] -= X_mean[j] * R_sum
-                    XtA[j] -= beta * w[j]
-
-                if positive:
-                    dual_norm_XtA = max(n_features, &XtA[0])
-                else:
-                    dual_norm_XtA = abs_max(n_features, &XtA[0])
-
-                # R_norm2 = np.dot(R, R)
-                if no_sample_weights:
-                    R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
-                else:
-                    R_norm2 = 0.0
-                    for i in range(n_samples):
-                        # R is already multiplied by sample_weight
-                        if sample_weight[i] != 0:
-                            R_norm2 += (R[i] ** 2) / sample_weight[i]
-
-                # w_norm2 = np.dot(w, w)
-                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
-                if (dual_norm_XtA > alpha):
-                    const_ = alpha / dual_norm_XtA
-                    A_norm2 = R_norm2 * const_**2
-                    gap = 0.5 * (R_norm2 + A_norm2)
-                else:
-                    const_ = 1.0
-                    gap = R_norm2
-
-                    l1_norm = _asum(n_features, &w[0], 1)
-
-                    gap += (alpha * l1_norm
-                            - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
-                            + 0.5 * beta * (1 + const_ ** 2) * w_norm2)
+                gap, dual_norm_XtA = sparse_gap_enet(
+                    n_samples,
+                    n_features,
+                    w,
+                    alpha,
+                    beta,
+                    X_data,
+                    X_indices,
+                    X_indptr,
+                    y,
+                    sample_weight,
+                    no_sample_weights,
+                    X_mean,
+                    center,
+                    R,
+                    R_sum,
+                    XtA,
+                    positive,
+                )
                 if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
+                # Gap Safe Screening Rules, see https://arxiv.org/abs/1802.07481, Eq. 11
+                if do_screening:
+                    n_active = 0
+                    for j in range(n_features):
+                        if excluded_set[j]:
+                            continue
+                        Xj_theta = XtA[j] / fmax(alpha, dual_norm_XtA)  # X[:,j] @ dual_theta
+                        d_j = (1 - fabs(Xj_theta)) / sqrt(norm2_cols_X[j] + beta)
+                        if d_j <= sqrt(2 * gap) / alpha:
+                            # include feature j
+                            active_set[n_active] = j
+                            excluded_set[j] = 0
+                            n_active += 1
+                        else:
+                            # R += w[j] * X[:,j]
+                            R_plus_wj_Xj(
+                                n_samples,
+                                R,
+                                X_data,
+                                X_indices,
+                                X_indptr,
+                                X_mean,
+                                center,
+                                sample_weight,
+                                no_sample_weights,
+                                w[j],
+                                j,
+                            )
+                            w[j] = 0
+                            excluded_set[j] = 1
+
         else:  # for/else, runs if for doesn't end with a `break`
             with gil:
diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index abf1f13de8c23..a6af31c3b0071 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -687,6 +687,7 @@ def enet_path(
                 rng=rng,
                 random=random,
                 positive=positive,
+                do_screening=do_screening,
             )
     elif multi_output:
         model = cd_fast.enet_coordinate_descent_multi_task(
diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py
index aa073b9a5080b..0b1ac1faa0a9c 100644
--- a/sklearn/linear_model/tests/test_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_coordinate_descent.py
@@ -103,16 +103,19 @@ def test_cython_solver_equivalence():
         "positive": False,
     }
 
-    coef_1 = np.zeros(X.shape[1])
-    coef_2, coef_3, coef_4 = coef_1.copy(), coef_1.copy(), coef_1.copy()
+    def zc():
+        """Create a new zero coefficient array (zc)."""
+        return np.zeros(X.shape[1])
 
     # For alpha_max, coefficients must all be zero.
+ coef_1 = zc() cd_fast.enet_coordinate_descent( w=coef_1, alpha=alpha_max, X=X_centered, y=y, **params ) assert_allclose(coef_1, 0) # Without gap safe screening rules + coef_1 = zc() cd_fast.enet_coordinate_descent( w=coef_1, alpha=alpha, X=X_centered, y=y, **params, do_screening=False ) @@ -120,6 +123,7 @@ def test_cython_solver_equivalence(): assert 2 <= np.sum(np.abs(coef_1) > 1e-8) < X.shape[1] # With gap safe screening rules + coef_2 = zc() cd_fast.enet_coordinate_descent( w=coef_2, alpha=alpha, X=X_centered, y=y, **params, do_screening=True ) @@ -127,20 +131,24 @@ def test_cython_solver_equivalence(): # Sparse Xs = sparse.csc_matrix(X) - cd_fast.sparse_enet_coordinate_descent( - w=coef_3, - alpha=alpha, - X_data=Xs.data, - X_indices=Xs.indices, - X_indptr=Xs.indptr, - y=y, - sample_weight=None, - X_mean=X_mean, - **params, - ) - assert_allclose(coef_3, coef_1) + for do_screening in [True, False]: + coef_3 = zc() + cd_fast.sparse_enet_coordinate_descent( + w=coef_3, + alpha=alpha, + X_data=Xs.data, + X_indices=Xs.indices, + X_indptr=Xs.indptr, + y=y, + sample_weight=None, + X_mean=X_mean, + **params, + do_screening=do_screening, + ) + assert_allclose(coef_3, coef_1) # Gram + coef_4 = zc() cd_fast.enet_coordinate_descent_gram( w=coef_4, alpha=alpha, @@ -842,14 +850,8 @@ def test_warm_start_convergence(sparse_X): model.set_params(warm_start=True) model.fit(X, y) n_iter_warm_start = model.n_iter_ - if sparse_X: - # TODO: sparse_enet_coordinate_descent is not yet updated. - # Fit the same model again, using a warm start: the optimizer just performs - # a single pass before checking that it has already converged - assert n_iter_warm_start == 1 - else: - # enet_coordinate_descent checks dual gap before entering the main loop - assert n_iter_warm_start == 0 + # coordinate descent checks dual gap before entering the main loop + assert n_iter_warm_start == 0 def test_warm_start_convergence_with_regularizer_decrement(): @@ -940,9 +942,9 @@ def test_sparse_dense_descent_paths(csr_container): X, y, _, _ = build_dataset(n_samples=50, n_features=20) csr = csr_container(X) for path in [enet_path, lasso_path]: - _, coefs, _ = path(X, y) - _, sparse_coefs, _ = path(csr, y) - assert_array_almost_equal(coefs, sparse_coefs) + _, coefs, _ = path(X, y, tol=1e-10) + _, sparse_coefs, _ = path(csr, y, tol=1e-10) + assert_allclose(coefs, sparse_coefs) @pytest.mark.parametrize("path_func", [enet_path, lasso_path]) From e98cf6eb33d2dd97c72a4ee9a01cbf3cac49781e Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Aug 2025 22:00:44 +0200 Subject: [PATCH 4/8] CLN forgotten declaration of 2 variables --- sklearn/linear_model/_cd_fast.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index b3f5dacd4ed50..308c1443b0c5b 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -583,6 +583,8 @@ def sparse_enet_coordinate_descent( cdef floating[::1] XtA = np.empty(n_features, dtype=dtype) cdef const floating[::1] yw + cdef floating d_j + cdef floating Xj_theta cdef floating tmp cdef floating w_j cdef floating d_w_max From b0ca2f92f100b1e0243117763f65e6749f4a2329 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Aug 2025 22:02:22 +0200 Subject: [PATCH 5/8] TST fix test_same_output_sparse_dense_lasso_and_enet_cv --- .../tests/test_sparse_coordinate_descent.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git 
a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py index d0472778aac22..d7d85763f8a86 100644 --- a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py @@ -306,23 +306,23 @@ def test_sparse_dense_equality( @pytest.mark.parametrize("csc_container", CSC_CONTAINERS) def test_same_output_sparse_dense_lasso_and_enet_cv(csc_container): X, y = make_sparse_data(csc_container, n_samples=40, n_features=10) - clfs = ElasticNetCV(max_iter=100) + clfs = ElasticNetCV(max_iter=100, tol=1e-7) clfs.fit(X, y) - clfd = ElasticNetCV(max_iter=100) + clfd = ElasticNetCV(max_iter=100, tol=1e-7) clfd.fit(X.toarray(), y) - assert_almost_equal(clfs.alpha_, clfd.alpha_, 7) - assert_almost_equal(clfs.intercept_, clfd.intercept_, 7) - assert_array_almost_equal(clfs.mse_path_, clfd.mse_path_) - assert_array_almost_equal(clfs.alphas_, clfd.alphas_) + assert_allclose(clfs.alpha_, clfd.alpha_) + assert_allclose(clfs.intercept_, clfd.intercept_) + assert_allclose(clfs.mse_path_, clfd.mse_path_) + assert_allclose(clfs.alphas_, clfd.alphas_) - clfs = LassoCV(max_iter=100, cv=4) + clfs = LassoCV(max_iter=100, cv=4, tol=1e-8) clfs.fit(X, y) - clfd = LassoCV(max_iter=100, cv=4) + clfd = LassoCV(max_iter=100, cv=4, tol=1e-8) clfd.fit(X.toarray(), y) - assert_almost_equal(clfs.alpha_, clfd.alpha_, 7) - assert_almost_equal(clfs.intercept_, clfd.intercept_, 7) - assert_array_almost_equal(clfs.mse_path_, clfd.mse_path_) - assert_array_almost_equal(clfs.alphas_, clfd.alphas_) + assert_allclose(clfs.alpha_, clfd.alpha_) + assert_allclose(clfs.intercept_, clfd.intercept_) + assert_allclose(clfs.mse_path_, clfd.mse_path_) + assert_allclose(clfs.alphas_, clfd.alphas_) @pytest.mark.parametrize("coo_container", COO_CONTAINERS) From 72e4bde7aef9435897f5a50e6f33f9418fb69401 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 22 Aug 2025 13:37:41 +0200 Subject: [PATCH 6/8] DOC add whatsnew --- .../{31882.efficiency.rst => 31986.efficiency.rst} | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) rename doc/whats_new/upcoming_changes/sklearn.linear_model/{31882.efficiency.rst => 31986.efficiency.rst} (80%) diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/31882.efficiency.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/31986.efficiency.rst similarity index 80% rename from doc/whats_new/upcoming_changes/sklearn.linear_model/31882.efficiency.rst rename to doc/whats_new/upcoming_changes/sklearn.linear_model/31986.efficiency.rst index 55e0679b4b375..66d341e58f8ec 100644 --- a/doc/whats_new/upcoming_changes/sklearn.linear_model/31882.efficiency.rst +++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/31986.efficiency.rst @@ -1,11 +1,12 @@ - :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`, :class:`linear_model.Lasso`, :class:`linear_model.LassoCV` as well as :func:`linear_model.lasso_path` and :func:`linear_model.enet_path` now implement - gap safe screening rules in the coordinate descent solver for dense `X` and - `precompute=False` or `"auto"` with `n_samples < n_features`. + gap safe screening rules in the coordinate descent solver for dense `X` (with + `precompute=False` or `"auto"` with `n_samples < n_features`) and sparse `X` + (always). The speedup of fitting time is particularly pronounced (10-times is possible) when computing regularization paths like the \*CV-variants of the above estimators do. 
There is now an additional check of the stopping criterion before entering
   the main loop of descent steps. As the stopping criterion requires the
   computation of the dual gap, the screening happens whenever the dual gap is
   computed.
-  By :user:`Christian Lorentzen <lorentzenchr>`.
+  By :user:`Christian Lorentzen <lorentzenchr>` :pr:`31882` and :pr:`31986`.

From cefe385824fa6a229816775931f70250d4ea7aa7 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 26 Aug 2025 07:49:55 +0200
Subject: [PATCH 7/8] CLN rename sparse_gap_enet to gap_enet_sparse

---
 sklearn/linear_model/_cd_fast.pyx | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index 308c1443b0c5b..3cc01d5d57bcc 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -434,7 +434,7 @@ cdef inline void R_plus_wj_Xj(
                 R[i] -= sample_weight[i] * X_mean_j * w_j
 
 
-cdef (floating, floating) sparse_gap_enet(
+cdef (floating, floating) gap_enet_sparse(
     int n_samples,
     int n_features,
     const floating[::1] w,
@@ -671,7 +671,7 @@ def sparse_enet_coordinate_descent(
         tol *= _dot(n_samples, &y[0], 1, &yw[0], 1)
 
         # Check convergence before entering the main loop.
-        gap, dual_norm_XtA = sparse_gap_enet(
+        gap, dual_norm_XtA = gap_enet_sparse(
            n_samples,
            n_features,
            w,
@@ -803,7 +803,7 @@ def sparse_enet_coordinate_descent(
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
-                gap, dual_norm_XtA = sparse_gap_enet(
+                gap, dual_norm_XtA = gap_enet_sparse(
                     n_samples,
                     n_features,
                     w,

From 130a37b3c458d95bed1f509eb658945e38fbea28 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 26 Aug 2025 22:34:19 +0200
Subject: [PATCH 8/8] DOC fix upper bound dual gap

---
 doc/modules/linear_model.rst | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index 2492e84cab38a..46bf23bce3b09 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -319,12 +319,14 @@ It stops if the duality gap is smaller than the provided tolerance `tol`.
   The duality gap :math:`G(w, v)` is an upper bound of the difference between the
   current primal objective function of the Lasso, :math:`P(w)`, and its minimum
-  :math:`P(w^\star)`, i.e. :math:`G(w, v) \leq P(w) - P(w^\star)`. It is given by
+  :math:`P(w^\star)`, i.e. :math:`G(w, v) \geq P(w) - P(w^\star)`. It is given by
   :math:`G(w, v) = P(w) - D(v)` with dual objective function
 
   .. math:: D(v) = \frac{1}{2n_{\text{samples}}}(2y^Tv - ||v||_2^2)
 
   subject to :math:`||X^Tv||_{\infty} \leq n_{\text{samples}}\alpha`.
+  At optimum, the duality gap is zero, :math:`G(w^\star, v^\star) = 0` (a property
+  called strong duality).
   With (scaled) dual variable :math:`v = c r`, current residual :math:`r = y - Xw`
   and dual scaling
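
Taken together, the series computes the dual gap in one place (gap_enet_sparse)
and reuses it both as an early-exit convergence check and as input to the gap
safe screening rule. As a reading aid, here is a minimal NumPy sketch of those
two steps for dense X — assuming no sample weights and no centering, with
`alpha` and `beta` already multiplied by `n_samples` as in the solver. The
names `gap_enet_dense` and `gap_safe_keep_mask` are illustrative only, not part
of the scikit-learn API.

import numpy as np

def gap_enet_dense(w, alpha, beta, X, y):
    """Dual gap of 0.5 * ||y - Xw||^2 + alpha * ||w||_1 + 0.5 * beta * ||w||^2.

    Plain NumPy transcription of gap_enet_sparse above (dense X, no sample
    weights, no centering).
    """
    R = y - X @ w                         # residual
    XtA = X.T @ R - beta * w
    dual_norm_XtA = np.max(np.abs(XtA))
    R_norm2 = R @ R
    # Scale the dual point v = const_ * R into the feasible set
    # {v : ||X^T v||_inf <= alpha}.
    if dual_norm_XtA > alpha:
        const_ = alpha / dual_norm_XtA
        gap = 0.5 * (R_norm2 + const_**2 * R_norm2)
    else:
        const_ = 1.0
        gap = R_norm2
    gap += (
        alpha * np.abs(w).sum()
        - const_ * (R @ y)
        + 0.5 * beta * (1 + const_**2) * (w @ w)
    )
    return gap, dual_norm_XtA, XtA

def gap_safe_keep_mask(alpha, beta, X, gap, dual_norm_XtA, XtA):
    """Gap safe screening rule (arXiv:1802.07481, Eq. 11).

    Returns a boolean mask of features that may still be active at the
    optimum; the complement can safely be fixed at zero for this alpha.
    """
    norm2_cols_X = (X**2).sum(axis=0)
    # Correlation of each column with the scaled dual point theta.
    Xj_theta = XtA / max(alpha, dual_norm_XtA)
    d = np.full(X.shape[1], np.inf)       # zero columns are always dropped
    nonzero = norm2_cols_X + beta > 0
    d[nonzero] = (1.0 - np.abs(Xj_theta[nonzero])) / np.sqrt(
        norm2_cols_X[nonzero] + beta
    )
    return d <= np.sqrt(2.0 * gap) / alpha

Close to alpha_max = ||X^T y||_inf, where the solution is all zeros, the gap at
w = 0 is already small, so the radius sqrt(2 * gap) / alpha shrinks and most
features are discarded before a single coordinate update — which is why the
speedup is most pronounced along regularization paths as computed by the
\*CV estimators.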