| Title: | Global Explanations for Tree-Based Models |
|---|---|
| Description: | Global explanations for tree-based models by decomposing regression or classification functions into the sum of main components and interaction components of arbitrary order. Calculates SHAP values and q-interaction SHAP for all values of q for tree-based models such as xgboost. |
| Authors: | Marvin N. Wright [aut, cre] (ORCID: <https://orcid.org/0000-0002-8542-6291>), Joseph Theo Meyer [aut], Munir Hiabu [aut], Lukas Burk [aut] (ORCID: <https://orcid.org/0000-0001-7528-3795>), Jinyang Liu [aut] (ORCID: <https://orcid.org/0009-0005-3167-9014>) |
| Maintainer: | Marvin N. Wright <[email protected]> |
| License: | GPL-3 |
| Version: | 0.5.2 |
| Built: | 2026-05-22 09:03:11 UTC |
| Source: | https://github.com/PlantedML/glex |
Plotting the main effects among the prediction components is effectively identical to a partial dependence plot, centered to 0.
## S3 method for class 'glex' autoplot(object, predictors, ...) plot_main_effect(object, predictor, rug_sides = "b", ...) plot_threeway_effects(object, predictors, rug_sides = "b", ...) plot_twoway_effects(object, predictors, rug_sides = "b", ...)
object |
Object of class |
... |
Used for future expansion. |
predictor, predictors
|
|
rug_sides |
|
A ggplot2 object.
Other Visualization functions:
autoplot.glex_vi(),
glex_explain(),
plot_pdp()
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Main effects ---- plot_main_effect(components, "wt") plot_main_effect(components, "cyl") } # plot_threeway_effects(components, c("hr", "temp", "workingday")) if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # 2-degree interaction effects ---- # 2d continuous, scatterplot of arbitrary orientation plot_twoway_effects(components, c("wt", "drat")) # flipped: plot_twoway_effects(components, c("drat", "wt")) # continuous + categorical (forces continuous on x axis, colors by categorical) plot_twoway_effects(components, c("wt", "cyl")) # identical: plot_twoway_effects(components, c("cyl", "wt")) # 2d categorical, heatmap of arbitrary orientation plot_twoway_effects(components, c("vs", "cyl")) plot_twoway_effects(components, c("cyl", "vs")) }
Plot glex Variable Importances
## S3 method for class 'glex_vi' autoplot( object, by_degree = FALSE, threshold = 0, max_interaction = NULL, scale = "absolute", ... )
object |
Object of class |
by_degree |
( |
threshold |
( |
max_interaction |
( |
scale |
( |
... |
(Unused) |
A ggplot object.
Other Visualization functions:
autoplot.glex(),
glex_explain(),
plot_pdp()
A reduced version of the Bikeshare data as included with ISLR2.
The dataset has been converted to a data.table::data.table(), with the following changes:
bike
An object of class data.table (inherits from data.frame) with 8645 rows and 11 columns.
hr has been converted to a numeric
workingday was recoded to a binary factor with labels c("No Workingday", "Workingday")
season was recoded to a factor with labels c("Winter", "Spring", "Summer", "Fall")
Variables atemp, day, registered and casual were removed
Bikeshare in package ISLR2
Global explanations for tree-based models by decomposing regression or classification functions into the sum of main components and interaction components of arbitrary order. Calculates SHAP values and q-interaction SHAP for all values of q for tree-based models such as xgboost.
glex(object, x, max_interaction = NULL, features = NULL, ...) ## Default S3 method: glex(object, ...) ## S3 method for class 'rpf' glex(object, x, max_interaction = NULL, features = NULL, ...) ## S3 method for class 'xgb.Booster' glex( object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ... ) ## S3 method for class 'ranger' glex( object, x, max_interaction = NULL, features = NULL, max_background_sample_size = NULL, weighting_method = "fastpd", ... )
object |
Model to be explained, either of class |
x |
Data to be explained. |
max_interaction |
( |
features |
Vector of column names in x to calculate components for. Default is |
... |
Further arguments passed to methods. |
max_background_sample_size |
The maximum number of background samples used for the FastPD algorithm, only used when |
weighting_method |
Use either "path-dependent", "fastpd" (default), or "empirical". See References for details. |
For parallel execution using xgboost models, register a backend, e.g. with
doParallel::registerDoParallel().
The different weighting methods are described in detail in Liu et al. (2024). The default method is "fastpd" as it consistently estimates the correct partial dependence function.
Decomposition of the regression or classification function.
A list with elements:
shap: SHAP values (xgboost method only).
m: Functional decomposition into all main and interaction
components in the model, up to the degree specified by max_interaction.
The variable names correspond to the original variable names,
with : separating interaction terms as one would specify in a formula interface.
intercept: Intercept term, the expected value of the prediction.
Liu, J., Steensgaard, T., Wright, M. N., Pfister, N., & Hiabu, M. (2024). Fast Estimation of Partial Dependence Functions using Trees. arXiv preprint arXiv:2410.13448.
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) str(glex_rpf, list.len = 5) } # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex(xg, x[27:32, ]) glex(xg, mtcars[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(xg, x[27:32, ]) ## End(Not run) } # ranger ----- if (requireNamespace("ranger", quietly = TRUE)) { library(ranger) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg rf <- ranger(x = x[1:26, ], y = y[1:26], num.trees = 5, max.depth = 3, node.stats = TRUE) glex(rf, x[27:32, ]) glex(rf, mtcars[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(rf, x[27:32, ]) ## End(Not run) }
Plots the prediction components for a single observation, identified by the row number in the dataset used
with glex().
Since the resulting plot can be quite busy due to a potentially large number of elements, it is highly
recommended to use predictors, max_interaction, or threshold to restrict the number of
elements in the plot.
glex_explain( object, id, threshold = 0, max_interaction = NULL, predictors = NULL, class = NULL, barheight = 0.5 )
object |
Object of class |
id |
( |
threshold |
( |
max_interaction |
( |
predictors |
( |
class |
( |
barheight |
( |
A ggplot object.
Other Visualization functions:
autoplot.glex(),
autoplot.glex_vi(),
plot_pdp()
set.seed(1) # Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) glex_explain(glex_rpf, id = 3, predictors = "hp", threshold = 0.01) }
Variable Importance for Main and Interaction Terms
glex_vi(object, ...)
object |
Object of class |
... |
(Unused) |
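The m reported here is the average absolute value of m as reported by glex(), aggregated by term:
$$\mathtt{m} = \frac{1}{n} \sum_{i = 1}^{n} \left| m_{\mathrm{term}}^{(i)} \right|$$
In turn, m_rel rescales m by the average prediction of the model ($m_0$, intercept as reported by glex()):
$$\mathtt{m\_rel} = \frac{\mathtt{m}}{m_0}$$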
A data.table::data.table() with columns:
degree (integer): Degree of interaction of the term, with 1 being main effects,
2 being 2-degree interactions etc.
term (character): Model term, e.g. main effect x1 or interaction term x1:x2, x1:x3:x5 etc.
class (factor): For multiclass targets only: The associated target class. Lists all classes in the
target, not limited to the majority vote.
m (numeric): Average absolute contribution of term, see Details.
m_rel (numeric): m but relative to the average prediction (intercept in glex() output).
set.seed(1) # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex_xgb <- glex(xg, x[27:32, ]) vi_xgb <- glex_vi(glex_xgb) library(ggplot2) autoplot(vi_xgb) autoplot(vi_xgb, by_degree = TRUE) }
A version of plot_main_effect with the intercept term (horizontal line) added,
resulting in a partial dependence plot.
plot_pdp(object, predictor, rug_sides = "b", ...)
object |
Object of class |
predictor |
|
rug_sides |
|
... |
Used for future expansion. |
A ggplot2 object.
Other Visualization functions:
autoplot.glex(),
autoplot.glex_vi(),
glex_explain()
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) plot_pdp(components, "wt") plot_pdp(components, "cyl") }
This is implemented mainly to avoid flooding the console in cases where the glex object
uses many terms, which leads to a large number of column names of $m being printed to the console.
This function wraps str() with a truncated output for a more compact representation.
## S3 method for class 'glex' print(x, ...)
x |
Object to print. |
... |
(Unused) |
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) glex(rp, mtcars[27:32, ]) }
Subset components
subset_components(components, term) subset_component_names(components, term)
components |
An object of class |
term |
( |
subset_components: An object of class glex.
subset_component_names: A character vector.
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Get component object with only "hp" and its interactions subset_components(components, "hp") subset_component_names(components, "hp") }
This is a slight variation of ggplot2::theme_minimal() with increased font size.
theme_glex( base_size = 13, base_family = "", base_line_size = base_size/22, base_rect_size = base_size/22, grid_x = TRUE, grid_y = FALSE )
base_size |
( |
base_family |
( |
base_line_size, base_rect_size
|
( |
grid_x |
( |
grid_y |
( |
A ggplot2 theme object
library(ggplot2) ggplot(mtcars, aes(wt, mpg)) + geom_point() + theme_glex()