-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
FIX : broken MultiTaskLasso with warm_start=True #12853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1796,7 +1796,7 @@ def fit(self, X, y): | |
X, y, X_offset, y_offset, X_scale = _preprocess_data( | ||
X, y, self.fit_intercept, self.normalize, copy=False) | ||
|
||
if not self.warm_start or self.coef_ is None: | ||
if not self.warm_start or not hasattr(self, "coef_"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jnothman shall I replace There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm okay with it either way, actually. I just never know how flexible we should make these kinds of changes... I mean if someone's gone and changed their code now to set |
||
self.coef_ = np.zeros((n_tasks, n_features), dtype=X.dtype.type, | ||
order='F') | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we were trying to access a fit param before actually setting it.