Remove double instantiations for cuml.explainer kernels
#7384
base: main
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
LGTM, but do we want to remove the float or the double instantiation? Is there an argument for one side or the other? Also, instead of branching before the C++ call, would it be possible to use the BaseEstimator features to have the output type automatically match the input type, while ensuring that the arrays used internally always match the type we choose for the C++ instantiations?
python/cuml/cuml/explainer/base.pyx
Outdated
masked_ptr_f32 = masked_inputs_f32.__cuda_array_interface__['data'][0]
bg_ptr_f32 = background_f32.__cuda_array_interface__['data'][0]
Using get_cai_ptr here might be nice for consistency.
Addressed in 8dc222c
    <bool> row_major)

# Cast result back to float64
masked_inputs[:] = masked_inputs_f32.astype(cp.float64)
Similarly, please use get_cai_ptr here.
Addressed in 8dc222c
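The `masked_inputs[:] = ...` form in the snippet above is significant: slice assignment writes the float32 result into the caller's existing float64 buffer, whereas a plain `masked_inputs = ...` would only rebind a local name and leave the caller's array unchanged. Below is a minimal NumPy sketch of that difference. NumPy stands in for CuPy here on the assumption that the two share these assignment semantics, and `fill_in_place` / `fill_rebind` are hypothetical names.

```python
import numpy as np


def fill_in_place(out):
    """Write a float32 result into the caller's float64 buffer via slice assignment."""
    result_f32 = np.arange(out.size, dtype=np.float32).reshape(out.shape)
    out[:] = result_f32.astype(np.float64)  # copies into the existing buffer


def fill_rebind(out):
    """Rebinding the local name leaves the caller's buffer untouched."""
    result_f32 = np.arange(out.size, dtype=np.float32).reshape(out.shape)
    out = result_f32.astype(np.float64)  # only the local variable changes
    return out


buf = np.zeros((2, 3), dtype=np.float64)
fill_in_place(buf)
print(buf.dtype, buf[1, 2])  # float64 5.0

buf2 = np.zeros((2, 3), dtype=np.float64)
fill_rebind(buf2)
print(buf2.sum())  # 0.0
```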
Thank you for the review, @viclafargue!
It seems that other areas of the codebase support float32 more often than double, and float32 also has a smaller memory footprint.
What is the process to enable the BaseEstimator features? Does BaseEstimator automatically handle the casting of user input and output?
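To make the float32-vs-double trade-off concrete: float32 halves the memory of every buffer handed to the kernels, at the cost of a much coarser machine epsilon (fewer significant digits). Below is a small NumPy sketch. The array shape is arbitrary and chosen only for illustration; it does not reflect any measurement from this PR.

```python
import numpy as np

# Arbitrary example size: 10,000 rows x 100 features.
shape = (10_000, 100)

f64 = np.zeros(shape, dtype=np.float64)
f32 = np.zeros(shape, dtype=np.float32)

print(f64.nbytes, f32.nbytes)  # 8000000 4000000

# Machine epsilon: relative spacing between adjacent representable values.
print(np.finfo(np.float64).eps)
print(np.finfo(np.float32).eps)
```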
For example, see the estimator guide. But I just realized that the dtype doesn't come from the data's dtype; it comes from the SHAPBase argument. I guess it is fine to route the C++ calls that way if we are doing the rest of the calculations in the…
Resolves #7001
`cuml.explainer` algorithms are used for model interpretability, usually with tree models, and are based on https://github.com/shap/shap.

This PR removes `double` function support for `kernel_dataset`, `permutation_shap_dataset`, `shap_main_effect_dataset`, and `update_perm_shap_values`, which in turn removes the `double` instantiations of the `cuml.explainer`-related kernels.

Users may still pass `double` inputs in the Python layer. However, those inputs are cast to `float` for the actual computation, and the result is cast back to a `double` output.
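The dtype handling described above can be sketched as a thin wrapper. Whatever dtype the user passes, the data is cast to float32 for the single remaining kernel instantiation, and the result is cast back to the input's dtype. This is a minimal sketch that assumes host-side NumPy arrays. `_float32_kernel` and `run_with_float32_kernel` are hypothetical names, and the kernel body is a placeholder, not cuML's SHAP computation.

```python
import numpy as np


def _float32_kernel(x):
    """Placeholder for a kernel that only has a float32 instantiation."""
    if x.dtype != np.float32:
        raise TypeError("kernel only supports float32")
    return x * np.float32(2.0)  # stand-in computation


def run_with_float32_kernel(x):
    """Cast input to float32, run the kernel, cast the result back to x's dtype."""
    x = np.asarray(x)
    # Keep float32/float64 as the output dtype; default anything else to float64.
    out_dtype = x.dtype if x.dtype in (np.float32, np.float64) else np.float64
    x_f32 = x.astype(np.float32, copy=False)  # no copy if already float32
    result_f32 = _float32_kernel(x_f32)
    return result_f32.astype(out_dtype, copy=False)


res64 = run_with_float32_kernel(np.array([1.0, 2.5], dtype=np.float64))
res32 = run_with_float32_kernel(np.array([1.0, 2.5], dtype=np.float32))
print(res64.dtype, res64)  # float64 [2. 5.]
print(res32.dtype, res32)  # float32 [2. 5.]
```

Note that a float64 input still passes through float32 precision (roughly seven significant digits) during the computation, even though the returned array is float64.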