-
-
Notifications
You must be signed in to change notification settings - Fork 26.4k
FEA Add array API support to LabelBinarizer(sparse_output=False) for numeric labels
#32582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FEA Add array API support to LabelBinarizer(sparse_output=False) for numeric labels
#32582
Conversation
…r numeric labels
…inarizer_Array_API
|
I'll fix the CI later. |
…inarizer_Array_API
…inarizer_Array_API
|
CI is green, including CUDA. Ready for review, @ogrisel, @OmarManzoor! |
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR @virchan
An initial set of comments.
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few more comments. Generally looks good
ogrisel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the existing code with array API support that relies on LabelBinarizer to check that this PR can help simplify it. I think we only have the following two occurrences, but I am not 100% sure:
scikit-learn/sklearn/linear_model/_ridge.py
Lines 1324 to 1329 in ca39ad1
# TODO: Update this line to avoid calling `_convert_to_numpy` # once LabelBinarizer has been updated to accept non-NumPy array API # compatible inputs. Y = self._label_binarizer.fit_transform( _convert_to_numpy(y, xp_y) if y_is_array_api else y ) scikit-learn/sklearn/metrics/_classification.py
Lines 185 to 193 in ca39ad1
# For classification metrics both array API compatible and non array API # compatible inputs are allowed for `y_true`. This is because arrays that # store class labels as strings cannot be represented in namespaces other # than Numpy. Thus to avoid unnecessary complexity, we always convert # `y_true` to a Numpy array so that it can be processed appropriately by # `LabelBinarizer` and then transfer the integer encoded output back to the # target namespace and device. if is_y_true_array_api: y_true = _convert_to_numpy(y_true, xp=xp_y_true)
OmarManzoor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM. Thank you @virchan
lucyleeow
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this! Just a few questions but looks good to me. I'm not that familiar with label binarizer though, so I would feel more comfortable if someone else looked at this...
f0dd7f7 to
0edd742
Compare
virchan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the force-push.
|
Update: CUDA CI passed after merging #32705. |
lucyleeow
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One small question about the atol change, but LGTM - I think I had a good look so will just merge after 😬
Co-authored-by: Vivaan Nanavati <[email protected]> Co-authored-by: Maren Westermann <[email protected]>
This reverts commit 2fa9c93.
|
I ran the CI with the In particular, I updated |
Reference Issues/PRs
Towards #26024 and #32422 (comment).
What does this implement/fix? Explain your changes.
This PR adds Array API support to
LabelBinarizerandlabel_binarizewhensparse_output=Falsefor numeric labels, and therefore does not conflict with #30439 (comment). Specifically,Both
LabelBinarizerandlabel_binarizewill raise aValueErrorwhen the inputyhas a non-NumPy namespace andsparse_output=True.If
LabelBinarizeris fitted on a sparse matrix (i.e.,sparse_input_=True), callinginverse_transformon a non-NumPy array will raise aValueError.3. If the inputclassescontains string labels,label_binarizewill automatically fall back to the NumPy namespace.Any other comments?
Adjusted the
atolvalue in thetest_graphical_lassosfunction due to a CI failure with `random_seed=95, discovered during local testing.