diff --git a/NEWS.md b/NEWS.md index 6c94d92f..51eaea7c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,6 +18,7 @@ - `$convergence_history` tracks convergence history and can be analyzed to see per-feature values after each checkpoint - `$plot_convergence_history()` plots convergence history per feature - Convergence is tracked only for first resampling iteration + - Also add standard error tracking as part of the convergence history ([#33](https://github.com/jemus42/xplainfi/pull/33)) # xplainfi 0.1.0 diff --git a/R/SAGE.R b/R/SAGE.R index f7a596ae..e5b0e40b 100644 --- a/R/SAGE.R +++ b/R/SAGE.R @@ -2,29 +2,39 @@ #' #' @description Base class for SAGE (Shapley Additive Global Importance) #' feature importance based on Shapley values with marginalization. -#' This is an abstract class - use MarginalSAGE or ConditionalSAGE. +#' This is an abstract class - use [MarginalSAGE] or [ConditionalSAGE]. #' #' @details #' SAGE uses Shapley values to fairly distribute the total prediction #' performance among all features. Unlike perturbation-based methods, #' SAGE marginalizes features by integrating over their distribution. #' This is approximated by averaging predictions over a reference dataset. +#' +#' **Standard Error Calculation**: The standard errors (SE) reported in +#' `$convergence_history` reflect the uncertainty in Shapley value estimation +#' across different random permutations within a single resampling iteration. +#' These SEs quantify the Monte Carlo sampling error for a fixed trained model +#' and are only valid for inference about the importance of features for that +#' specific model. They do not capture broader uncertainty from model variability +#' across different train/test splits or resampling iterations. #' #' @references #' `r print_bib("lundberg_2020")` #' +#' @seealso [MarginalSAGE] [ConditionalSAGE] +#' #' @export SAGE = R6Class( "SAGE", inherit = FeatureImportanceMethod, public = list( - #' @field n_permutations (integer(1)) Number of permutations to sample. + #' @field n_permutations (`integer(1)`) Number of permutations to sample. n_permutations = NULL, - #' @field reference_data (data.table) Reference dataset for marginalization. + #' @field reference_data ([`data.table`][data.table::data.table]) Reference dataset for marginalization. reference_data = NULL, #' @field sampler ([FeatureSampler]) Sampler object for marginalization. sampler = NULL, - #' @field convergence_history ([data.table]) History of SAGE values during computation. + #' @field convergence_history ([`data.table`][data.table::data.table]) History of SAGE values during computation. convergence_history = NULL, #' @field converged (`logical(1)`) Whether convergence was detected. converged = FALSE, @@ -36,11 +46,16 @@ SAGE = R6Class( #' @param task,learner,measure,resampling,features Passed to FeatureImportanceMethod. #' @param n_permutations (`integer(1): 10L`) Number of permutations _per coalition_ to sample for Shapley value estimation. #' The total number of evaluated coalitions is `1 (empty) + n_permutations * n_features`. - #' @param reference_data (`data.table | NULL`) Optional reference dataset. If `NULL`, uses training data. + #' @param reference_data ([`data.table`][data.table::data.table] | `NULL`) Optional reference dataset. If `NULL`, uses training data. #' For each coalition to evaluate, an expanded datasets of size `n_test * n_reference` is created and evaluted in batches of `batch_size`. #' @param batch_size (`integer(1): 5000L`) Maximum number of observations to process in a single prediction call. #' @param sampler ([FeatureSampler]) Sampler for marginalization. Only relevant for `ConditionalSAGE`. #' @param max_reference_size (`integer(1): 100L`) Maximum size of reference dataset. If reference is larger, it will be subsampled. + #' @param early_stopping (`logical(1): FALSE`) Whether to enable early stopping based on convergence detection. + #' @param convergence_threshold (`numeric(1): 0.01`) Relative change threshold for convergence detection. + #' @param se_threshold (`numeric(1): Inf`) Standard error threshold for convergence detection. + #' @param min_permutations (`integer(1): 10L`) Minimum permutations before checking convergence. + #' @param check_interval (`integer(1): 2L`) Check convergence every N permutations. initialize = function( task, learner, @@ -51,7 +66,12 @@ SAGE = R6Class( reference_data = NULL, batch_size = 5000L, sampler = NULL, - max_reference_size = 100L + max_reference_size = 100L, + early_stopping = FALSE, + convergence_threshold = 0.01, + se_threshold = Inf, + min_permutations = 10L, + check_interval = 2L ) { super$initialize( task = task, @@ -99,28 +119,36 @@ SAGE = R6Class( max_reference_size = paradox::p_int(lower = 1L, default = 100L), early_stopping = paradox::p_lgl(default = FALSE), convergence_threshold = paradox::p_dbl(lower = 0, upper = 1, default = 0.01), + se_threshold = paradox::p_dbl(lower = 0, default = Inf), min_permutations = paradox::p_int(lower = 5L, default = 10L), check_interval = paradox::p_int(lower = 1L, default = 2L) ) ps$values$n_permutations = n_permutations ps$values$batch_size = batch_size ps$values$max_reference_size = max_reference_size + ps$values$early_stopping = early_stopping + ps$values$convergence_threshold = convergence_threshold + ps$values$se_threshold = se_threshold + ps$values$min_permutations = min_permutations + ps$values$check_interval = check_interval self$param_set = ps }, #' @description #' Compute SAGE values. - #' @param store_backends (logical(1)) Whether to store backends. - #' @param batch_size (integer(1): 5000L) Maximum number of observations to process in a single prediction call. - #' @param early_stopping (logical(1)) Whether to check for convergence and stop early. - #' @param convergence_threshold (numeric(1)) Relative change threshold for convergence detection. - #' @param min_permutations (integer(1)) Minimum permutations before checking convergence. - #' @param check_interval (integer(1)) Check convergence every N permutations. + #' @param store_backends (`logical(1)`) Whether to store backends. + #' @param batch_size (`integer(1)`: `5000L`) Maximum number of observations to process in a single prediction call. + #' @param early_stopping (`logical(1)`) Whether to check for convergence and stop early. + #' @param convergence_threshold (`numeric(1)`) Relative change threshold for convergence detection. + #' @param se_threshold (`numeric(1)`) Standard error threshold for convergence detection. + #' @param min_permutations (`integer(1)`) Minimum permutations before checking convergence. + #' @param check_interval (`integer(1)`) Check convergence every N permutations. compute = function( store_backends = TRUE, batch_size = NULL, early_stopping = NULL, convergence_threshold = NULL, + se_threshold = NULL, min_permutations = NULL, check_interval = NULL ) { @@ -146,6 +174,11 @@ SAGE = R6Class( self$param_set$values$convergence_threshold, 0.01 ) + se_threshold = resolve_param( + se_threshold, + self$param_set$values$se_threshold, + Inf + ) min_permutations = resolve_param( min_permutations, self$param_set$values$min_permutations, @@ -173,6 +206,7 @@ SAGE = R6Class( batch_size = batch_size, early_stopping = early_stopping, convergence_threshold = convergence_threshold, + se_threshold = se_threshold, min_permutations = min_permutations, check_interval = check_interval ) @@ -193,6 +227,7 @@ SAGE = R6Class( batch_size = batch_size, early_stopping = FALSE, # Only track convergence for first iteration convergence_threshold = convergence_threshold, + se_threshold = se_threshold, min_permutations = min_permutations, check_interval = check_interval ) @@ -221,7 +256,7 @@ SAGE = R6Class( #' @description #' Plot convergence history of SAGE values. #' @param features (`character` | `NULL`) Features to plot. If NULL, plots all features. - #' @return A ggplot2 object + #' @return A [ggplot2][ggplot2::ggplot] object plot_convergence = function(features = NULL) { require_package("ggplot2") @@ -238,8 +273,12 @@ SAGE = R6Class( p = ggplot2::ggplot( plot_data, - ggplot2::aes(x = n_permutations, y = importance, color = feature) + ggplot2::aes(x = n_permutations, y = importance, fill = feature, color = feature) ) + + ggplot2::geom_ribbon( + ggplot2::aes(ymin = importance - se, ymax = importance + se), + alpha = 1 / 3 + ) + ggplot2::geom_line(size = 1) + ggplot2::geom_point(size = 2) + ggplot2::labs( @@ -255,9 +294,10 @@ SAGE = R6Class( }, x = "Number of Permutations", y = "SAGE Value", - color = "Feature" + color = "Feature", + fill = "Feature" ) + - ggplot2::theme_minimal(base_size = 12) + ggplot2::theme_minimal(base_size = 14) if (self$converged) { p = p + @@ -280,16 +320,19 @@ SAGE = R6Class( batch_size = NULL, early_stopping = FALSE, convergence_threshold = 0.01, + se_threshold = Inf, min_permutations = 10L, check_interval = 5L ) { # This function computes the SAGE values for a single resampling iteration. # It iterates through permutations of features, evaluates coalitions, and calculates marginal contributions. - # Initialize a numeric vector to store the sum of marginal contributions for each feature. - # These sums will later be averaged to get the final SAGE values. - sage_values = numeric(length(self$features)) - names(sage_values) = self$features # Name elements by feature names (e.g., c(x1 = 0, x2 = 0, x3 = 0)) + # Initialize numeric vectors to store marginal contributions and their squares for variance calculation. + # We track both sum and sum of squares to calculate running variance and standard errors. + sage_values = numeric(length(self$features)) # Sum of marginal contributions + sage_values_sq = numeric(length(self$features)) # Sum of squared marginal contributions + names(sage_values) = self$features + names(sage_values_sq) = self$features # Pre-generate ALL permutations upfront to ensure consistent random state. # Relevant for reproducibility, especially when using early stopping or parallel processing. @@ -413,7 +456,9 @@ SAGE = R6Class( marginal_contribution = prev_loss - current_loss # Add this marginal contribution to the total SAGE value for the 'feature'. + # Also track the squared contribution for variance calculation. sage_values[feature] = sage_values[feature] + marginal_contribution + sage_values_sq[feature] = sage_values_sq[feature] + marginal_contribution^2 # Update 'prev_loss' for the next iteration in this permutation. # The current coalition's loss becomes the 'previous' loss for the next feature's contribution. @@ -424,15 +469,23 @@ SAGE = R6Class( # Update the count of completed permutations. n_completed = n_completed + checkpoint_size - # Calculate the current average SAGE values based on completed permutations. + # Calculate the current average SAGE values and standard errors based on completed permutations. current_avg = sage_values / n_completed - # Store the current average SAGE values in the convergence history. - # Used for plotting and early stopping. + # Calculate running variance and standard errors for each feature + # Variance = E[X^2] - E[X]^2, SE = sqrt(Var / n) + current_variance = (sage_values_sq / n_completed) - (current_avg^2) + # Ensure variance is non-negative (numerical precision issues) + current_variance[current_variance < 0] = 0 + current_se = sqrt(current_variance / n_completed) + + # Store the current average SAGE values and standard errors in the convergence history. + # Used for plotting, early stopping, and uncertainty quantification. checkpoint_history = data.table( n_permutations = n_completed, feature = names(current_avg), - importance = as.numeric(current_avg) # Ensure numeric, not named vector + importance = as.numeric(current_avg), + se = as.numeric(current_se) ) convergence_history[[length(convergence_history) + 1]] = checkpoint_history @@ -445,20 +498,43 @@ SAGE = R6Class( # Ensure features are in the same order for comparison. prev_values = copy(prev_checkpoint)[order(feature)]$importance curr_values = copy(curr_checkpoint)[order(feature)]$importance + curr_se_values = copy(curr_checkpoint)[order(feature)]$se # Calculate the maximum relative change between current and previous SAGE values. # A small max_change indicates convergence. rel_changes = abs(curr_values - prev_values) / (abs(prev_values) + 1e-8) # Add epsilon to avoid division by zero max_change = max(rel_changes, na.rm = TRUE) - # If the maximum relative change is below the threshold, mark as converged. - if (is.finite(max_change) && max_change < convergence_threshold) { - converged = TRUE - cli::cli_inform(c( + # Calculate maximum standard error across features. + max_se = max(curr_se_values, na.rm = TRUE) + + # SE threshold is already resolved as a parameter to this function + + # Check both relative change and standard error convergence criteria. + rel_change_converged = is.finite(max_change) && max_change < convergence_threshold + se_converged = is.finite(max_se) && max_se < se_threshold + + # Convergence requires both criteria to be met (when SE threshold is finite) + if (is.finite(se_threshold)) { + converged = rel_change_converged && se_converged + convergence_msg = c( + "v" = "SAGE converged after {.val {n_completed}} permutations", + "i" = "Maximum relative change: {.val {round(max_change, 4)}} (threshold: {.val {convergence_threshold}})", + "i" = "Maximum standard error: {.val {round(max_se, 4)}} (threshold: {.val {se_threshold}})", + "i" = "Saved {.val {self$n_permutations - n_completed}} permutations" + ) + } else { + # If SE threshold is infinite, only check relative change + converged = rel_change_converged + convergence_msg = c( "v" = "SAGE converged after {.val {n_completed}} permutations", "i" = "Maximum relative change: {.val {round(max_change, 4)}}", "i" = "Saved {.val {self$n_permutations - n_completed}} permutations" - )) + ) + } + + if (converged) { + cli::cli_inform(convergence_msg) } } } @@ -702,9 +778,11 @@ SAGE = R6Class( #' @title Marginal SAGE #' -#' @description SAGE with marginal sampling (features are marginalized independently). +#' @description [SAGE] with marginal sampling (features are marginalized independently). #' This is the standard SAGE implementation. #' +#' @seealso [ConditionalSAGE] +#' #' @examplesIf requireNamespace("ranger", quietly = TRUE) && requireNamespace("mlr3learners", quietly = TRUE) #' library(mlr3) #' task = tgen("friedman1")$generate(n = 100) @@ -725,11 +803,7 @@ MarginalSAGE = R6Class( public = list( #' @description #' Creates a new instance of the MarginalSAGE class. - #' @param task,learner,measure,resampling,features Passed to [SAGE]. - #' @param n_permutations (integer(1)) Number of permutations to sample. - #' @param reference_data (data.table) Optional reference dataset. - #' @param max_reference_size (integer(1)) Maximum size of reference dataset. - #' @param batch_size (`integer(1): 5000L`) Maximum number of observations to process in a single prediction call. + #' @param task,learner,measure,resampling,features,n_permutations,reference_data,batch_size,max_reference_size,early_stopping,convergence_threshold,se_threshold,min_permutations,check_interval Passed to [SAGE]. initialize = function( task, learner, @@ -739,7 +813,12 @@ MarginalSAGE = R6Class( n_permutations = 10L, reference_data = NULL, batch_size = 5000L, - max_reference_size = 100L + max_reference_size = 100L, + early_stopping = FALSE, + convergence_threshold = 0.01, + se_threshold = Inf, + min_permutations = 10L, + check_interval = 2L ) { # No need to initialize sampler as marginal sampling is done differently here super$initialize( @@ -751,7 +830,12 @@ MarginalSAGE = R6Class( n_permutations = n_permutations, reference_data = reference_data, batch_size = batch_size, - max_reference_size = max_reference_size + max_reference_size = max_reference_size, + early_stopping = early_stopping, + convergence_threshold = convergence_threshold, + se_threshold = se_threshold, + min_permutations = min_permutations, + check_interval = check_interval ) self$label = "Marginal SAGE" @@ -771,8 +855,10 @@ MarginalSAGE = R6Class( #' @title Conditional SAGE #' -#' @description SAGE with conditional sampling (features are marginalized conditionally). -#' Uses ARF by default for conditional marginalization. +#' @description [SAGE] with conditional sampling (features are "marginalized" conditionally). +#' Uses [ARFSampler] as default [ConditionalSampler]. +#' +#' @seealso [MarginalSAGE] #' #' @examplesIf requireNamespace("ranger", quietly = TRUE) && requireNamespace("mlr3learners", quietly = TRUE) && requireNamespace("arf", quietly = TRUE) #' library(mlr3) @@ -794,12 +880,8 @@ ConditionalSAGE = R6Class( public = list( #' @description #' Creates a new instance of the ConditionalSAGE class. - #' @param task,learner,measure,resampling,features Passed to [SAGE]. - #' @param n_permutations (integer(1)) Number of permutations to sample. - #' @param reference_data (data.table) Optional reference dataset. - #' @param sampler ([ConditionalSampler]) Optional custom sampler. Defaults to ARFSampler. - #' @param max_reference_size (integer(1)) Maximum size of reference dataset. - #' @param batch_size (`integer(1): 5000L`) Maximum number of observations to process in a single prediction call. + #' @param task,learner,measure,resampling,features,n_permutations,reference_data,batch_size,max_reference_size,early_stopping,convergence_threshold,se_threshold,min_permutations,check_interval Passed to [SAGE]. + #' @param sampler ([ConditionalSampler]) Optional custom sampler. Defaults to [ARFSampler]. initialize = function( task, learner, @@ -810,7 +892,12 @@ ConditionalSAGE = R6Class( reference_data = NULL, sampler = NULL, batch_size = 5000L, - max_reference_size = 100L + max_reference_size = 100L, + early_stopping = FALSE, + convergence_threshold = 0.01, + se_threshold = Inf, + min_permutations = 10L, + check_interval = 2L ) { # Use ARFSampler by default if (is.null(sampler)) { @@ -832,7 +919,12 @@ ConditionalSAGE = R6Class( reference_data = reference_data, sampler = sampler, batch_size = batch_size, - max_reference_size = max_reference_size + max_reference_size = max_reference_size, + early_stopping = early_stopping, + convergence_threshold = convergence_threshold, + se_threshold = se_threshold, + min_permutations = min_permutations, + check_interval = check_interval ) self$label = "Conditional SAGE" @@ -845,19 +937,19 @@ ConditionalSAGE = R6Class( n_coalitions = length(all_coalitions) n_test = nrow(test_dt) n_reference = nrow(self$reference_data) - + if (xplain_opt("debug")) { cli::cli_inform("Evaluating {.val {length(all_coalitions)}} coalitions (ConditionalSAGE)") } - + # Pre-allocate list for expanded data all_expanded_data = vector("list", n_coalitions) - + # For each coalition, do conditional sampling BEFORE expansion for (i in seq_along(all_coalitions)) { coalition = all_coalitions[[i]] marginalize_features = setdiff(self$features, coalition) - + if (length(marginalize_features) > 0) { # Sample conditionally for unique test instances conditioning_set = coalition @@ -866,55 +958,57 @@ ConditionalSAGE = R6Class( data = test_dt, conditioning_set = conditioning_set ) - + # Create the marginalized test data marginalized_test = copy(test_dt) - marginalized_test[, (marginalize_features) := sampled_data[, marginalize_features, with = FALSE]] + marginalized_test[, + (marginalize_features) := sampled_data[, marginalize_features, with = FALSE] + ] } else { # No marginalization needed marginalized_test = copy(test_dt) } - + # NOW expand with reference data (only once, with correct values) test_expanded = marginalized_test[rep(seq_len(n_test), each = n_reference)] reference_expanded = self$reference_data[rep(seq_len(n_reference), times = n_test)] - + # Add tracking IDs test_expanded[, .coalition_id := i] test_expanded[, .test_instance_id := rep(seq_len(n_test), each = n_reference)] - + all_expanded_data[[i]] = test_expanded } - + # Rest of the method is the same as base implementation combined_data = rbindlist(all_expanded_data) total_rows = nrow(combined_data) - + # Process data in batches if needed if (!is.null(batch_size) && total_rows > batch_size) { n_batches = ceiling(total_rows / batch_size) all_predictions = vector("list", n_batches) - + for (batch_idx in seq_len(n_batches)) { start_row = (batch_idx - 1) * batch_size + 1 end_row = min(batch_idx * batch_size, total_rows) batch_data = combined_data[start_row:end_row] - + if (xplain_opt("debug")) { cli::cli_inform( "Predicting on {.val {nrow(batch_data)}} instances in batch {.val {batch_idx}/{n_batches}}" ) } - + pred_result = learner$predict_newdata(newdata = batch_data, task = self$task) - + if (self$task$task_type == "classif") { all_predictions[[batch_idx]] = pred_result$prob } else { all_predictions[[batch_idx]] = pred_result$response } } - + # Combine predictions if (self$task$task_type == "classif") { combined_predictions = do.call(rbind, all_predictions) @@ -926,16 +1020,16 @@ ConditionalSAGE = R6Class( if (xplain_opt("debug")) { cli::cli_inform("Predicting on {.val {nrow(combined_data)}} instances at once") } - + pred_result = learner$predict_newdata(newdata = combined_data, task = self$task) - + if (self$task$task_type == "classif") { combined_predictions = pred_result$prob } else { combined_predictions = pred_result$response } } - + # Handle NAs in predictions if (self$task$task_type == "classif") { if (any(is.na(combined_predictions))) { @@ -948,27 +1042,27 @@ ConditionalSAGE = R6Class( } else { combined_predictions[is.na(combined_predictions)] = 0 } - + # Aggregate predictions by coalition and test instance if (self$task$task_type == "classif") { n_classes = ncol(combined_predictions) class_names = colnames(combined_predictions) - + for (j in seq_len(n_classes)) { combined_data[, paste0(".pred_class_", j) := combined_predictions[, j]] } - + agg_cols = paste0(".pred_class_", seq_len(n_classes)) avg_preds_by_coalition = combined_data[, lapply(.SD, function(x) mean(x, na.rm = TRUE)), .SDcols = agg_cols, by = .(.coalition_id, .test_instance_id) ] - + setnames(avg_preds_by_coalition, agg_cols, class_names) } else { combined_data[, .prediction := combined_predictions] - + avg_preds_by_coalition = combined_data[, .( avg_pred = mean(.prediction, na.rm = TRUE) @@ -976,16 +1070,16 @@ ConditionalSAGE = R6Class( by = .(.coalition_id, .test_instance_id) ] } - + # Calculate loss for each coalition coalition_losses = numeric(n_coalitions) for (i in seq_len(n_coalitions)) { coalition_data = avg_preds_by_coalition[.coalition_id == i] - + if (self$task$task_type == "classif") { class_names = self$task$class_names prob_matrix = as.matrix(coalition_data[, .SD, .SDcols = class_names]) - + pred_obj = PredictionClassif$new( row_ids = seq_len(n_test), truth = test_dt[[self$task$target_names]], @@ -998,10 +1092,10 @@ ConditionalSAGE = R6Class( response = coalition_data$avg_pred ) } - + coalition_losses[i] = pred_obj$score(self$measure) } - + coalition_losses } ) diff --git a/man/ConditionalSAGE.Rd b/man/ConditionalSAGE.Rd index 726057bf..6c77224b 100644 --- a/man/ConditionalSAGE.Rd +++ b/man/ConditionalSAGE.Rd @@ -4,8 +4,8 @@ \alias{ConditionalSAGE} \title{Conditional SAGE} \description{ -SAGE with conditional sampling (features are marginalized conditionally). -Uses ARF by default for conditional marginalization. +\link{SAGE} with conditional sampling (features are "marginalized" conditionally). +Uses \link{ARFSampler} as default \link{ConditionalSampler}. } \examples{ \dontshow{if (requireNamespace("ranger", quietly = TRUE) && requireNamespace("mlr3learners", quietly = TRUE) && requireNamespace("arf", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} @@ -23,6 +23,9 @@ sage$compute() sage$compute(batch_size = 1000) \dontshow{\}) # examplesIf} } +\seealso{ +\link{MarginalSAGE} +} \section{Super classes}{ \code{\link[xplainfi:FeatureImportanceMethod]{xplainfi::FeatureImportanceMethod}} -> \code{\link[xplainfi:SAGE]{xplainfi::SAGE}} -> \code{ConditionalSAGE} } @@ -60,24 +63,21 @@ Creates a new instance of the ConditionalSAGE class. reference_data = NULL, sampler = NULL, batch_size = 5000L, - max_reference_size = 100L + max_reference_size = 100L, + early_stopping = FALSE, + convergence_threshold = 0.01, + se_threshold = Inf, + min_permutations = 10L, + check_interval = 2L )}\if{html}{\out{}} } \subsection{Arguments}{ \if{html}{\out{