
jemus42 (Member) commented Sep 25, 2025

At a minimum, $importance() aggregates importance scores using the measure's aggregator, e.g. the mean.

Additionally, it can now estimate variances and confidence intervals, depending on the argument variance_method:

  • none: Current behavior, no variance estimation
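  • raw: Raw, uncorrected variances across all resampling iterations, refits, etc.
  • nadeau_bengio: Follows the recommendation of Molnar et al. (2023) to apply the Nadeau and Bengio correction factor (1/m + n2/n1), where m is the number of resampling iterations and n2/n1 is the ratio of test-set size to training-set size. Only allowed for bootstrap and subsampling.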
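The three options above can be sketched for a single feature as follows. This is a minimal illustration, not the xplainfi implementation: the function name sketch_importance_ci and its arguments are hypothetical, and it assumes the aggregator is the mean and that the raw standard error is sd(scores) / sqrt(m), with a t-based interval on m - 1 degrees of freedom.

```r
# Hypothetical sketch, not the xplainfi implementation: aggregate per-iteration
# importance scores for one feature and optionally attach a t-based CI.
sketch_importance_ci = function(scores,
                                variance_method = c("none", "raw", "nadeau_bengio"),
                                n_train = NULL, n_test = NULL, conf_level = 0.95) {
  variance_method = match.arg(variance_method)
  m = length(scores)   # number of resampling iterations
  est = mean(scores)   # aggregation step, assuming the aggregator is the mean

  if (variance_method == "none") {
    return(data.frame(importance = est))
  }

  # "raw": uncorrected variance of the mean, var(scores) / m
  correction = 1 / m
  if (variance_method == "nadeau_bengio") {
    if (is.null(n_train) || is.null(n_test)) {
      stop("nadeau_bengio needs the training and test set sizes")
    }
    # Nadeau and Bengio: replace 1/m by 1/m + n2/n1 (test size / training size)
    correction = 1 / m + n_test / n_train
  }

  se = sqrt(correction * var(scores))
  q = qt(1 - (1 - conf_level) / 2, df = m - 1)
  data.frame(importance = est, se = se,
             conf_lower = est - q * se, conf_upper = est + q * se)
}

# Toy scores for illustration only (not taken from the PR)
scores = c(1.10, 1.25, 0.98, 1.31, 1.12, 1.05)
sketch_importance_ci(scores, "raw")
sketch_importance_ci(scores, "nadeau_bengio", n_train = 333, n_test = 167)
```

Because the correction only rescales the variance, the point estimate is identical across methods; only the standard error and interval width change.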

Since $importance() is (at least for now) shared by all feature importance methods, this applies to CFI, MarginalSAGE, etc. in the same way as it does to PFI. It probably needs a few prominent warnings making clear that these variance estimates may or may not be reasonable for anything other than PFI.
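The two restrictions described above (nadeau_bengio only for bootstrap or subsampling, and caution for methods other than PFI) could be encoded as an argument check. This is a hypothetical sketch: check_variance_method and its arguments are illustrative and not part of xplainfi; "bootstrap" and "subsampling" are the mlr3 resampling keys used with rsmp().

```r
# Hypothetical validation helper (not part of xplainfi): encodes the two rules
# described above for choosing a variance_method.
check_variance_method = function(method_class, resampling_id, variance_method) {
  variance_method = match.arg(variance_method, c("none", "raw", "nadeau_bengio"))

  # The correction assumes repeated random train/test splits
  if (variance_method == "nadeau_bengio" &&
      !resampling_id %in% c("bootstrap", "subsampling")) {
    stop("variance_method = 'nadeau_bengio' requires bootstrap or subsampling, not '",
         resampling_id, "'", call. = FALSE)
  }

  # Any variance estimate for non-PFI methods gets a warning label
  if (variance_method != "none" && method_class != "PFI") {
    warning("Variance estimates for ", method_class,
            " may not be reasonable; interpret confidence intervals with caution.",
            call. = FALSE)
  }
  invisible(variance_method)
}

check_variance_method("PFI", "subsampling", "nadeau_bengio")  # passes silently
```

Erroring for an invalid resampling but only warning for non-PFI methods mirrors the distinction in the text: the first is a hard requirement of the correction, the second is an open question.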

jemus42 changed the title to "WIP variance methods for $importance()" on Sep 25, 2025

jemus42 mentioned this pull request on Sep 29, 2025
jemus42 (Member, Author) commented Sep 29, 2025

library(xplainfi)
library(mlr3learners)
#> Loading required package: mlr3
library(ggplot2)

set.seed(123)

n = 500

pfi = PFI$new(
  task = sim_dgp_interactions(n = n),
  learner = lrn("regr.ranger", num.trees = 500),
  measure = msr("regr.mse"),
  # Subsampling instead of bootstrapping due to RF
  resampling = rsmp("subsampling", repeats = 15),
  # for stability of PFI estimates within resampling
  iters_perm = 5
)

pfi$compute()

# No variance estimation: "Safe"
pfi$importance(variance_method = "none")
#> Key: <feature>
#>    feature   importance
#>     <char>        <num>
#> 1:  noise1 -0.005907559
#> 2:  noise2  0.003732088
#> 3:      x1  1.150619588
#> 4:      x2  1.654139475
#> 5:      x3  1.571509170
# Raw, uncorrected variances: CIs are too narrow and shown for comparison only
pfi$importance(variance_method = "raw")
#> Key: <feature>
#>    feature   importance         se  conf_lower conf_upper
#>     <char>        <num>      <num>       <num>      <num>
#> 1:  noise1 -0.005907559 0.01901551 -0.04669177 0.03487665
#> 2:  noise2  0.003732088 0.01307179 -0.02430410 0.03176828
#> 3:      x1  1.150619588 0.08063310  0.97767880 1.32356038
#> 4:      x2  1.654139475 0.07344207  1.49662190 1.81165705
#> 5:      x3  1.571509170 0.05615073  1.45107783 1.69194051
# Better-than-uncorrected CIs
pfi$importance(variance_method = "nadeau_bengio")
#> Key: <feature>
#>    feature   importance         se  conf_lower conf_upper
#>     <char>        <num>      <num>       <num>      <num>
#> 1:  noise1 -0.005907559 0.05551266 -0.12497038 0.11315526
#> 2:  noise2  0.003732088 0.03816093 -0.07811498 0.08557915
#> 3:      x1  1.150619588 0.23539510  0.64574731 1.65549187
#> 4:      x2  1.654139475 0.21440208  1.19429274 2.11398621
#> 5:      x3  1.571509170 0.16392285  1.21992962 1.92308872

# Combine to compare
pfis = data.table::rbindlist(list(
  pfi$importance(variance_method = "raw")[, var_method := "raw"][],
  pfi$importance(variance_method = "nadeau_bengio")[, var_method := "nadeau_bengio"]
))[, width := conf_upper - conf_lower][order(feature)]
pfis
#>     feature   importance         se  conf_lower conf_upper    var_method
#>      <char>        <num>      <num>       <num>      <num>        <char>
#>  1:  noise1 -0.005907559 0.01901551 -0.04669177 0.03487665           raw
#>  2:  noise1 -0.005907559 0.05551266 -0.12497038 0.11315526 nadeau_bengio
#>  3:  noise2  0.003732088 0.01307179 -0.02430410 0.03176828           raw
#>  4:  noise2  0.003732088 0.03816093 -0.07811498 0.08557915 nadeau_bengio
#>  5:      x1  1.150619588 0.08063310  0.97767880 1.32356038           raw
#>  6:      x1  1.150619588 0.23539510  0.64574731 1.65549187 nadeau_bengio
#>  7:      x2  1.654139475 0.07344207  1.49662190 1.81165705           raw
#>  8:      x2  1.654139475 0.21440208  1.19429274 2.11398621 nadeau_bengio
#>  9:      x3  1.571509170 0.05615073  1.45107783 1.69194051           raw
#> 10:      x3  1.571509170 0.16392285  1.21992962 1.92308872 nadeau_bengio
#>          width
#>          <num>
#>  1: 0.08156843
#>  2: 0.23812564
#>  3: 0.05607238
#>  4: 0.16369413
#>  5: 0.34588159
#>  6: 1.00974455
#>  7: 0.31503516
#>  8: 0.91969347
#>  9: 0.24086269
#> 10: 0.70315911


pfis |>
  dplyr::filter(grepl("^x", feature)) |>
  ggplot(aes(y = feature, x = importance, color = var_method)) +
  geom_errorbar(aes(xmin = conf_lower, xmax = conf_upper), linewidth = 1, position = position_dodge(width = 0.9)) +
  geom_point(size = 1.5, shape = 21, position = position_dodge(width = .9)) +
  scale_color_brewer(palette = "Dark2") +
  labs(
    title = "PFI with confidence intervals",
    subtitle = "With and without variance correction",
    x = "PFI",
    y = "Feature",
    color = "Variance method"
  ) +
  theme_minimal(base_size = 14) +
  theme(legend.position = "top")

Created on 2025-09-29 with reprex v2.1.1
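As a rough check of the output above, the printed standard errors and intervals can be compared directly. The sketch copies the importance, se and conf_lower values from the reprex; it assumes each subsampling iteration used 333 training and 167 test observations (a 2/3 split of n = 500, which we believe is the default ratio of rsmp("subsampling") in mlr3). Under that assumption, the corrected-to-raw se ratio should be sqrt((1/m + n2/n1) / (1/m)) for every feature, and each CI half-width should be the t quantile with m - 1 = 14 degrees of freedom times the se.

```r
# Values copied from the printed reprex output (m = 15 subsampling iterations)
features   = c("noise1", "noise2", "x1", "x2", "x3")
importance = c(-0.005907559, 0.003732088, 1.150619588, 1.654139475, 1.571509170)
se_raw     = c(0.01901551, 0.01307179, 0.08063310, 0.07344207, 0.05615073)
se_nb      = c(0.05551266, 0.03816093, 0.23539510, 0.21440208, 0.16392285)
lower_raw  = c(-0.04669177, -0.02430410, 0.97767880, 1.49662190, 1.45107783)

m = 15
# Assumption: 2/3 subsampling of n = 500 gives 333 training and 167 test rows
n_train = 333
n_test = 167

# The correction is a constant factor, so the se ratio is the same for every feature
ratio = se_nb / se_raw
expected_ratio = sqrt((1 / m + n_test / n_train) / (1 / m))

# CI half-width divided by se recovers the t multiplier with m - 1 df
t_mult = (importance - lower_raw) / se_raw

print(data.frame(feature = features, ratio = round(ratio, 4), t_mult = round(t_mult, 4)))
print(round(c(expected_ratio = expected_ratio, t_975 = qt(0.975, df = m - 1)), 4))
```

Every ratio comes out at about 2.919, matching sqrt(1 + 15 * 167/333), and every t multiplier at about 2.145, i.e. qt(0.975, 14). This is why the nadeau_bengio intervals in the comparison table are uniformly about 2.9 times wider than the raw ones.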

jemus42 merged commit ff4b225 into main on Sep 30, 2025
7 checks passed
jemus42 deleted the variance-est branch on September 30, 2025 11:44