Thanks to visit codestin.com
Credit goes to github.com

Skip to content

FIX : broken MultiTaskLasso with warm_start=True #12853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ def fit(self, X, y):
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, self.fit_intercept, self.normalize, copy=False)

if not self.warm_start or self.coef_ is None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we were trying to access a fit param before actually setting it.

if not self.warm_start or not hasattr(self, "coef_"):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnothman shall I replace not hasattr(self, "coef_") by a getattr and we're good?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with it either way, actually. I just never know how flexible we should make these kinds of changes... I mean if someone's gone and changed their code now to set coef_ = None so that it works, using getattr would make sure it continues to work in the next release... but it's also hacky.

self.coef_ = np.zeros((n_tasks, n_features), dtype=X.dtype.type,
order='F')

Expand Down
8 changes: 8 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,11 @@ def test_coef_shape_not_zero():
est_no_intercept = Lasso(fit_intercept=False)
est_no_intercept.fit(np.c_[np.ones(3)], np.ones(3))
assert est_no_intercept.coef_.shape == (1,)


def test_multi_task_lasso_warm_start():
X = np.array([[1, 2, 4, 5, 8], [3, 5, 7, 7, 8]]).T
y = np.array([12, 10, 11, 21, 5])[:, np.newaxis]

est = MultiTaskLasso(warm_start=True)
est.fit(X, y)