Title: | Global Explanations for Tree-Based Models |
---|---|
Description: | Global explanations for tree-based models by decomposing regression or classification functions into the sum of main components and interaction components of arbitrary order. Calculates SHAP values and q-interaction SHAP for all values of q for tree-based models such as xgboost. |
Authors: | Marvin N. Wright [aut, cre] , Joseph Theo Meyer [aut], Munir Hiabu [aut], Lukas Burk [aut] , Jinyang Liu [aut] |
Maintainer: | Marvin N. Wright <[email protected]> |
License: | GPL-3 |
Version: | 0.4.2 |
Built: | 2024-12-21 06:12:52 UTC |
Source: | https://github.com/PlantedML/glex |
Plotting the main effects among the prediction components is effectively identical to a partial dependence plot, centered at 0.
## S3 method for class 'glex' autoplot(object, predictors, ...) plot_main_effect(object, predictor, rug_sides = "b", ...) plot_threeway_effects(object, predictors, rug_sides = "b", ...) plot_twoway_effects(object, predictors, rug_sides = "b", ...)
## S3 method for class 'glex' autoplot(object, predictors, ...) plot_main_effect(object, predictor, rug_sides = "b", ...) plot_threeway_effects(object, predictors, rug_sides = "b", ...) plot_twoway_effects(object, predictors, rug_sides = "b", ...)
object |
Object of class glex. |
... |
Used for future expansion. |
predictor , predictors
|
|
rug_sides |
|
A ggplot2
object.
Other Visualization functions:
autoplot.glex_vi()
,
glex_explain()
,
plot_pdp()
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Main effects ---- plot_main_effect(components, "wt") plot_main_effect(components, "cyl") } # plot_threeway_effects(components, c("hr", "temp", "workingday")) if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # 2-degree interaction effects ---- # 2d continuous, scatterplot of arbitrary orientation plot_twoway_effects(components, c("wt", "drat")) # flipped: plot_twoway_effects(components, c("drat", "wt")) # continuous + categorical (forces continuous on x axis, colors by categorical) plot_twoway_effects(components, c("wt", "cyl")) # identical: plot_twoway_effects(components, c("cyl", "wt")) # 2d categorical, heatmap of arbitrary orientation plot_twoway_effects(components, c("vs", "cyl")) plot_twoway_effects(components, c("cyl", "vs")) }
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Main effects ---- plot_main_effect(components, "wt") plot_main_effect(components, "cyl") } # plot_threeway_effects(components, c("hr", "temp", "workingday")) if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # 2-degree interaction effects ---- # 2d continuous, scatterplot of arbitrary orientation plot_twoway_effects(components, c("wt", "drat")) # flipped: plot_twoway_effects(components, c("drat", "wt")) # continuous + categorical (forces continuous on x axis, colors by categorical) plot_twoway_effects(components, c("wt", "cyl")) # identical: plot_twoway_effects(components, c("cyl", "wt")) # 2d categorical, heatmap of arbitrary orientation plot_twoway_effects(components, c("vs", "cyl")) plot_twoway_effects(components, c("cyl", "vs")) }
Plot glex Variable Importances
## S3 method for class 'glex_vi' autoplot( object, by_degree = FALSE, threshold = 0, max_interaction = NULL, scale = "absolute", ... )
## S3 method for class 'glex_vi' autoplot( object, by_degree = FALSE, threshold = 0, max_interaction = NULL, scale = "absolute", ... )
object |
Object of class glex_vi. |
by_degree |
( |
threshold |
( |
max_interaction |
( |
scale |
( |
... |
(Unused) |
A ggplot
object.
Other Visualization functions:
autoplot.glex()
,
glex_explain()
,
plot_pdp()
A reduced version of the Bikeshare
data as included with ISLR2
.
The dataset has been converted to a data.table
, with the following changes:
bike
bike
An object of class data.table
(inherits from data.frame
) with 8645 rows and 11 columns.
hr
has been converted to a numeric
workingday
was recoded to a binary factor
with labels c("No Workingday", "Workingday")
season
was recoded to a factor
with labels c("Winter", "Spring", "Summer", "Fall")
Variables atemp
, day
, registered
and casual
were removed
Bikeshare
in package ISLR2
Global explanations for tree-based models by decomposing regression or classification functions into the sum of main components and interaction components of arbitrary order. Calculates SHAP values and q-interaction SHAP for all values of q for tree-based models such as xgboost.
glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... ) ## S3 method for class 'rpf' glex(object, x, max_interaction = NULL, features = NULL, ...) ## S3 method for class 'xgb.Booster' glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... ) ## S3 method for class 'ranger' glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... )
glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... ) ## S3 method for class 'rpf' glex(object, x, max_interaction = NULL, features = NULL, ...) ## S3 method for class 'xgb.Booster' glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... ) ## S3 method for class 'ranger' glex( object, x, max_interaction = NULL, features = NULL, probFunction = NULL, ... )
object |
Model to be explained, either of class rpf, xgb.Booster, or ranger. |
x |
Data to be explained. |
max_interaction |
( |
features |
Vector of column names in x to calculate components for. Default is NULL. |
probFunction |
Either "path-dependent" to use the old path-dependent weighting of leaves, or a user-specified probability function with the signature function(coords, lb, ub). Defaults to NULL. |
... |
Further arguments passed to methods. |
For parallel execution using xgboost
or ranger models, register a parallel backend first, e.g. with
doParallel::registerDoParallel()
.
Decomposition of the regression or classification function.
A list
with elements:
shap
: SHAP values (xgboost
method only).
m
: Functional decomposition into all main and interaction
components in the model, up to the degree specified by max_interaction
.
The variable names correspond to the original variable names,
with :
separating interaction terms as one would specify in a formula
interface.
intercept
: Intercept term, the expected value of the prediction.
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) str(glex_rpf, list.len = 5) } # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex(xg, x[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(xg, x[27:32, ]) ## End(Not run) } # ranger ----- if (requireNamespace("ranger", quietly = TRUE)) { library(ranger) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg rf <- ranger(x = x[1:26, ], y = y[1:26], num.trees = 5, max.depth = 3, node.stats = TRUE) glex(rf, x[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(rf, x[27:32, ]) ## End(Not run) }
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) str(glex_rpf, list.len = 5) } # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex(xg, x[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(xg, x[27:32, ]) ## End(Not run) } # ranger ----- if (requireNamespace("ranger", quietly = TRUE)) { library(ranger) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg rf <- ranger(x = x[1:26, ], y = y[1:26], num.trees = 5, max.depth = 3, node.stats = TRUE) glex(rf, x[27:32, ]) ## Not run: # Parallel execution doParallel::registerDoParallel() glex(rf, x[27:32, ]) ## End(Not run) }
Plots the prediction components for a single observation, identified by the row number in the dataset used
with glex()
.
Since the resulting plot can be quite busy due to a potentially large number of elements, it is highly
recommended to use predictors
, max_interaction
, or threshold
to restrict the number of
elements in the plot.
glex_explain( object, id, threshold = 0, max_interaction = NULL, predictors = NULL, class = NULL, barheight = 0.5 )
glex_explain( object, id, threshold = 0, max_interaction = NULL, predictors = NULL, class = NULL, barheight = 0.5 )
object |
Object of class glex. |
id |
( |
threshold |
( |
max_interaction |
( |
predictors |
( |
class |
( |
barheight |
( |
A ggplot object.
Other Visualization functions:
autoplot.glex()
,
autoplot.glex_vi()
,
plot_pdp()
set.seed(1) # Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) glex_explain(glex_rpf, id = 3, predictors = "hp", threshold = 0.01) }
set.seed(1) # Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2) glex_rpf <- glex(rp, mtcars[27:32, ]) glex_explain(glex_rpf, id = 3, predictors = "hp", threshold = 0.01) }
Variable Importance for Main and Interaction Terms
glex_vi(object, ...)
glex_vi(object, ...)
object |
Object of class glex. |
... |
(Unused) |
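The m
reported here is the average absolute value of the components m
as reported by glex()
, aggregated by term
over the n explained observations:

$$\mathtt{m} = \frac{1}{n} \sum_{i=1}^{n} \lvert m_i \rvert$$

In turn, m_rel
rescales m
by the average prediction of the model ($m_0$, the
intercept
as reported by glex()
):

$$\mathtt{m\_rel} = \frac{\mathtt{m}}{m_0}$$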
A data.table
with columns:
degree
(integer
): Degree of interaction of the term
, with 1
being main effects,
2
being 2-degree interactions etc.
term
(character
): Model term, e.g. main effect x1
or interaction term x1:x2
, x1:x3:x5
etc.
class
(factor
): For multiclass targets only: The associated target class. Lists all classes in the
target, not limited to the majority vote.
m
(numeric
): Average absolute contribution of term
, see Details.
m_rel
(numeric
): m
but relative to the average prediction (intercept
in glex()
output).
set.seed(1) # Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) glex_rpf <- glex(rp, mtcars[27:32, ]) # All terms vi_rpf <- glex_vi(glex_rpf) library(ggplot2) # Filter to contributions greater 0.05 on the scale of the target autoplot(vi_rpf, threshold = 0.05) # Summarize by degree of interaction autoplot(vi_rpf, by_degree = TRUE) # Filter by relative contributions greater 0.1% autoplot(vi_rpf, scale = "relative", threshold = 0.001) } # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex_xgb <- glex(xg, x[27:32, ]) vi_xgb <- glex_vi(glex_xgb) library(ggplot2) autoplot(vi_xgb) autoplot(vi_xgb, by_degree = TRUE) }
set.seed(1) # Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 3) glex_rpf <- glex(rp, mtcars[27:32, ]) # All terms vi_rpf <- glex_vi(glex_rpf) library(ggplot2) # Filter to contributions greater 0.05 on the scale of the target autoplot(vi_rpf, threshold = 0.05) # Summarize by degree of interaction autoplot(vi_rpf, by_degree = TRUE) # Filter by relative contributions greater 0.1% autoplot(vi_rpf, scale = "relative", threshold = 0.001) } # xgboost ----- if (requireNamespace("xgboost", quietly = TRUE)) { library(xgboost) x <- as.matrix(mtcars[, -1]) y <- mtcars$mpg xg <- xgboost(data = x[1:26, ], label = y[1:26], params = list(max_depth = 4, eta = .1), nrounds = 10, verbose = 0) glex_xgb <- glex(xg, x[27:32, ]) vi_xgb <- glex_vi(glex_xgb) library(ggplot2) autoplot(vi_xgb) autoplot(vi_xgb, by_degree = TRUE) }
A version of plot_main_effect
with the intercept term (horizontal line) added,
resulting in a partial dependence plot.
plot_pdp(object, predictor, rug_sides = "b", ...)
plot_pdp(object, predictor, rug_sides = "b", ...)
object |
Object of class glex. |
predictor |
|
rug_sides |
|
... |
Used for future expansion. |
A ggplot2
object.
Other Visualization functions:
autoplot.glex()
,
autoplot.glex_vi()
,
glex_explain()
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) plot_pdp(components, "wt") plot_pdp(components, "cyl") }
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) plot_pdp(components, "wt") plot_pdp(components, "cyl") }
This is implemented mainly to avoid flooding the console in cases where the glex
object
uses many terms, which leads to a large number of column names of $m
being printed to the console.
This function wraps str()
with a truncated output for a more compact representation.
## S3 method for class 'glex' print(x, ...)
## S3 method for class 'glex' print(x, ...)
x |
Object to print. |
... |
(Unused) |
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) glex(rp, mtcars[27:32, ]) }
# Random Planted Forest ----- if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) rp <- rpf(mpg ~ hp + wt + drat, data = mtcars[1:26, ], max_interaction = 2) glex(rp, mtcars[27:32, ]) }
Subset components
subset_components(components, term) subset_component_names(components, term)
subset_components(components, term) subset_component_names(components, term)
components |
An object of class glex. |
term |
( |
subset_components
: An object of class glex
.
subset_component_names
: A character vector.
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Get component object with only "hp" and its interactions subset_components(components, "hp") subset_component_names(components, "hp") }
if (requireNamespace("randomPlantedForest", quietly = TRUE)) { library(randomPlantedForest) # introduce factor variables to show categorical feature handling mtcars$cyl <- factor(mtcars$cyl) mtcars$vs <- factor(mtcars$vs) # Fit forest, get components set.seed(12) rpfit <- rpf(mpg ~ cyl + wt + hp + drat + vs, data = mtcars, ntrees = 25, max_interaction = 3) components <- glex(rpfit, mtcars) # Get component object with only "hp" and its interactions subset_components(components, "hp") subset_component_names(components, "hp") }
This is a slight variation of ggplot2::theme_minimal()
with increased font size.
theme_glex( base_size = 13, base_family = "", base_line_size = base_size/22, base_rect_size = base_size/22, grid_x = TRUE, grid_y = FALSE )
theme_glex( base_size = 13, base_family = "", base_line_size = base_size/22, base_rect_size = base_size/22, grid_x = TRUE, grid_y = FALSE )
base_size |
( |
base_family |
( |
base_line_size , base_rect_size
|
( |
grid_x |
( |
grid_y |
( |
A ggplot2
theme object
library(ggplot2) ggplot(mtcars, aes(wt, mpg)) + geom_point() + theme_glex()
library(ggplot2) ggplot(mtcars, aes(wt, mpg)) + geom_point() + theme_glex()