
Commit 37c65e5

FIX: Use coordinate_descent_gram when precompute is True | auto
1 parent 083b5f5

File tree: 1 file changed, +8 -4 lines


sklearn/linear_model/coordinate_descent.py

@@ -467,12 +467,16 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
                 coef_, l1_reg, l2_reg, X.data, X.indices,
                 X.indptr, y, X_sparse_scaling,
                 max_iter, tol, positive)
-        elif not multi_output:
-            model = cd_fast.enet_coordinate_descent(
-                coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
-        else:
+        elif multi_output:
             model = cd_fast.enet_coordinate_descent_multi_task(
                 coef_, l1_reg, l2_reg, X, y, max_iter, tol)
+        elif isinstance(precompute, np.ndarray):
+            model = cd_fast.enet_coordinate_descent_gram(
+                coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter,
+                tol, positive)
+        else:
+            model = cd_fast.enet_coordinate_descent(
+                coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
         coef_, dual_gap_, eps_ = model
         coefs[..., i] = coef_
         dual_gaps[i] = dual_gap_
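
For context, a minimal usage sketch (not part of the commit) of how a caller can hand enet_path a precomputed Gram matrix so that the new isinstance(precompute, np.ndarray) branch dispatches to cd_fast.enet_coordinate_descent_gram. It assumes a scikit-learn version where enet_path accepts precompute and Xy as keyword arguments and returns (alphas, coefs, dual_gaps); the variable names gram and rng are illustrative only.

import numpy as np
from sklearn.linear_model import enet_path

# Illustrative data; any dense, Fortran-ordered X works here.
rng = np.random.RandomState(0)
X = np.asfortranarray(rng.randn(50, 10))
y = rng.randn(50)

# Precompute the Gram matrix X^T X and X^T y once, then reuse them
# along the regularization path (this is what the Gram solver expects).
gram = np.dot(X.T, X)   # shape (n_features, n_features)
Xy = np.dot(X.T, y)     # shape (n_features,)

alphas, coefs, dual_gaps = enet_path(X, y, l1_ratio=0.5,
                                     precompute=gram, Xy=Xy)
print(coefs.shape)      # (n_features, n_alphas)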
