WIP variance methods for `$importance()` #40
Merged
```r
library(xplainfi)
library(mlr3learners)
#> Loading required package: mlr3
library(ggplot2)

set.seed(123)
n = 500

pfi = PFI$new(
  task = sim_dgp_interactions(n = n),
  learner = lrn("regr.ranger", num.trees = 500),
  measure = msr("regr.mse"),
  # Subsampling instead of bootstrapping due to RF
  resampling = rsmp("subsampling", repeats = 15),
  # for stability of PFI estimates within resampling
  iters_perm = 5
)
pfi$compute()
```
```r
# No variance estimation: "Safe"
pfi$importance(variance_method = "none")
#> Key: <feature>
#> feature importance
#> <char> <num>
#> 1: noise1 -0.005907559
#> 2: noise2 0.003732088
#> 3: x1 1.150619588
#> 4: x2 1.654139475
#> 5: x3 1.571509170

# Raw variances for explicitly wrong CIs, for comparison purposes only
pfi$importance(variance_method = "raw")
#> Key: <feature>
#> feature importance se conf_lower conf_upper
#> <char> <num> <num> <num> <num>
#> 1: noise1 -0.005907559 0.01901551 -0.04669177 0.03487665
#> 2: noise2 0.003732088 0.01307179 -0.02430410 0.03176828
#> 3: x1 1.150619588 0.08063310 0.97767880 1.32356038
#> 4: x2 1.654139475 0.07344207 1.49662190 1.81165705
#> 5: x3 1.571509170 0.05615073 1.45107783 1.69194051
```
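The `raw` bounds above are consistent with a two-sided 95% t-interval on m - 1 = 14 degrees of freedom, i.e. one aggregated score per subsampling iteration (m = 15). The snippet below reproduces two rows from the printed `importance` and `se` values alone. It is a consistency check on the output under that assumption, not a description of xplainfi's internals.

```r
# Values copied from the "raw" output above
importance <- c(noise1 = -0.005907559, x1 = 1.150619588)
se_raw     <- c(noise1 = 0.01901551,   x1 = 0.08063310)

m <- 15                     # assumed: one score per subsampling iteration (repeats = 15)
q <- qt(0.975, df = m - 1)  # two-sided 95% t quantile, about 2.1448

conf_lower <- importance - q * se_raw
conf_upper <- importance + q * se_raw
cbind(conf_lower, conf_upper)
```

With the t quantile instead of the normal 1.96, the recomputed bounds agree with the printed `conf_lower` / `conf_upper` to the displayed precision.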
```r
# Better-than-uncorrected CIs
pfi$importance(variance_method = "nadeau_bengio")
#> Key: <feature>
#> feature importance se conf_lower conf_upper
#> <char> <num> <num> <num> <num>
#> 1: noise1 -0.005907559 0.05551266 -0.12497038 0.11315526
#> 2: noise2 0.003732088 0.03816093 -0.07811498 0.08557915
#> 3: x1 1.150619588 0.23539510 0.64574731 1.65549187
#> 4: x2 1.654139475 0.21440208 1.19429274 2.11398621
#> 5: x3 1.571509170 0.16392285 1.21992962 1.92308872
```
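Every `nadeau_bengio` SE here is the corresponding `raw` SE multiplied by the same factor, roughly 2.919. If the raw SE is s/sqrt(m) and the corrected SE is sqrt((1/m + n2/n1) s^2), where s is the sample standard deviation of the per-iteration scores, their ratio is sqrt(1 + m * n2/n1) for every feature. The sketch below checks this, assuming mlr3's default subsampling ratio of 2/3, which gives n1 = 333 training and n2 = 167 test observations per iteration for n = 500. These sizes are inferred, not read from the xplainfi code.

```r
n  <- 500
m  <- 15              # subsampling repeats
n1 <- round(n * 2/3)  # assumed training size per subsample: 333
n2 <- n - n1          # assumed test size per subsample: 167

# Predicted ratio of corrected to raw SE: sqrt(m * (1/m + n2/n1)) = sqrt(1 + m * n2/n1)
predicted <- sqrt(m * (1 / m + n2 / n1))

# Standard errors copied from the "raw" and "nadeau_bengio" outputs (noise1, noise2, x1, x2, x3)
se_raw <- c(0.01901551, 0.01307179, 0.08063310, 0.07344207, 0.05615073)
se_nb  <- c(0.05551266, 0.03816093, 0.23539510, 0.21440208, 0.16392285)
observed <- se_nb / se_raw

round(c(predicted = predicted, observed = observed), 4)
```

Because the factor does not depend on the feature, the correction widens every interval by the same proportion; it changes the uncertainty statement, not the ranking of the point estimates.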
```r
# Combine to compare
pfis = data.table::rbindlist(list(
  pfi$importance(variance_method = "raw")[, var_method := "raw"][],
  pfi$importance(variance_method = "nadeau_bengio")[, var_method := "nadeau_bengio"]
))[, width := conf_upper - conf_lower][order(xtfrm(feature))]
pfis
#> feature importance se conf_lower conf_upper var_method
#> <char> <num> <num> <num> <num> <char>
#> 1: noise1 -0.005907559 0.01901551 -0.04669177 0.03487665 raw
#> 2: noise1 -0.005907559 0.05551266 -0.12497038 0.11315526 nadeau_bengio
#> 3: noise2 0.003732088 0.01307179 -0.02430410 0.03176828 raw
#> 4: noise2 0.003732088 0.03816093 -0.07811498 0.08557915 nadeau_bengio
#> 5: x1 1.150619588 0.08063310 0.97767880 1.32356038 raw
#> 6: x1 1.150619588 0.23539510 0.64574731 1.65549187 nadeau_bengio
#> 7: x2 1.654139475 0.07344207 1.49662190 1.81165705 raw
#> 8: x2 1.654139475 0.21440208 1.19429274 2.11398621 nadeau_bengio
#> 9: x3 1.571509170 0.05615073 1.45107783 1.69194051 raw
#> 10: x3 1.571509170 0.16392285 1.21992962 1.92308872 nadeau_bengio
#> width
#> <num>
#> 1: 0.08156843
#> 2: 0.23812564
#> 3: 0.05607238
#> 4: 0.16369413
#> 5: 0.34588159
#> 6: 1.00974455
#> 7: 0.31503516
#> 8: 0.91969347
#> 9: 0.24086269
#> 10: 0.70315911
```
```r
pfis |>
  dplyr::filter(grepl("^x", feature)) |>
  ggplot(aes(y = feature, x = importance, color = var_method)) +
  geom_errorbar(aes(xmin = conf_lower, xmax = conf_upper), linewidth = 1, position = "dodge") +
  geom_point(size = 1.5, shape = 21, position = position_dodge(width = .9)) +
  scale_color_brewer(palette = "Dark2") +
  labs(
    title = "PFI with confidence intervals",
    subtitle = "With and without variance correction",
    x = "PFI",
    y = "Feature",
    color = "Variance method"
  ) +
  theme_minimal(base_size = 14) +
  theme(legend.position = "top")
```

Created on 2025-09-29 with reprex v2.1.1
At minimum, `$importance()` aggregates importance scores using the `measure`'s aggregator, e.g. `mean`. Additionally, it now supports the argument `variance_method`:

- `none`: Current behavior, no variance estimation.
- `raw`: Raw, uncorrected variances across all resamplings, refits, etc.
- `nadeau_bengio`: Following the recommendation from Molnar et al. (2023) to use the `1/m + (n2/n1)` correction factor. Only allowed for bootstrap / subsampling.

Since `$importance()` is (at least for now) shared by all feature importance methods, this applies to `PFI` the same way it does to `CFI`, `MarginalSAGE`, etc. It probably needs a few big warning labels to make clear that it may or may not be reasonable to use with anything other than PFI.
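A minimal sketch of the three options for a single feature, assuming one aggregated importance score per resampling iteration (permutation repeats already averaged) and a t-interval with m - 1 degrees of freedom, which is what the reprex output suggests. The helper `variance_ci()` and its arguments are hypothetical and not part of xplainfi's API; `n_train` and `n_test` stand for n1 and n2 in the correction factor.

```r
# Hypothetical helper (not xplainfi API): aggregate per-iteration importance
# scores for one feature and attach an SE and CI according to `variance_method`.
# Assumes `scores` holds one value per resampling iteration and that
# n_train / n_test are the per-iteration training and test set sizes.
variance_ci <- function(scores, variance_method = c("none", "raw", "nadeau_bengio"),
                        n_train = NULL, n_test = NULL, conf_level = 0.95) {
  variance_method <- match.arg(variance_method)
  m <- length(scores)
  importance <- mean(scores)
  if (variance_method == "none") {
    return(data.frame(importance = importance))
  }
  s2 <- var(scores)  # sample variance across resampling iterations
  if (variance_method == "raw") {
    # Treats iterations as independent although their training sets overlap,
    # so this SE tends to be too small
    se <- sqrt(s2 / m)
  } else {
    # Nadeau & Bengio corrected variance: (1/m + n2/n1) * s^2
    stopifnot(!is.null(n_train), !is.null(n_test))
    se <- sqrt((1 / m + n_test / n_train) * s2)
  }
  q <- qt(1 - (1 - conf_level) / 2, df = m - 1)
  data.frame(importance = importance, se = se,
             conf_lower = importance - q * se,
             conf_upper = importance + q * se)
}

set.seed(1)
scores <- rnorm(15, mean = 1.2, sd = 0.3)  # fake per-iteration scores, illustration only
variance_ci(scores, "raw")
variance_ci(scores, "nadeau_bengio", n_train = 333, n_test = 167)
```

Both variants share the same point estimate and quantile; only the variance term differs, so the corrected interval is always the wider of the two.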