-
-
Notifications
You must be signed in to change notification settings - Fork 26.6k
ENH Add Array API compatibility to KernelCenterer
#27556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| ) | ||
|
|
||
| K_pred_cols = (np.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, np.newaxis] | ||
| K_pred_cols = (xp.sum(K, axis=1) / self.K_fit_rows_.shape[0])[:, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be replaced with xp.newaxis for readability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually tried with xp.newaxis initially, but I got an AttributeError both with torch and numpy
AttributeError: module 'array_api_compat.torch' has no attribute 'newaxis'
AttributeError: module 'numpy.array_api' has no attribute 'newaxis'
@betatim do you know why newaxis may not be included?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately I don't know why newaxis doesn't exist for them. I would have guessed "bug in array_api_compat" if that was the only place where its missing, but I think ofnumpy.array_api as the "very much exactly what the standard says" implementation. One option is that it is something that only got added in a newer version of the standard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I opened data-apis/array-api-compat#64 upstream.
Ok for using None in scikit-learn in the mean time.
ogrisel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Along with #27642 I get the MPS tests to pass for this PR on my local macOS laptop.
I also ran the tests with pytorch and cupy on a CUDA host and all array_api tests pass for the KernelCenterer class.
Reference Issues/PRs
Towards #26024
What does this implement/fix? Explain your changes.
It makes the
KernelCentererimplementation compatible and tested with the Array API.