Title: | Random Planted Forest: A Directly Interpretable Tree Ensemble |
---|---|
Description: | An implementation of the Random Planted Forest algorithm for directly interpretable tree ensembles based on a functional ANOVA decomposition. |
Authors: | Joseph Theo Meyer [aut], Munir Hiabu [aut], Maike Spankus [aut], Marvin N. Wright [aut], Lukas Burk [cre, aut] |
Maintainer: | Lukas Burk <[email protected]> |
License: | Apache License (>= 2) |
Version: | 0.2.1.9000 |
Built: | 2024-11-04 13:42:25 UTC |
Source: | https://github.com/PlantedML/randomPlantedForest |
Prediction components are a functional decomposition of the model prediction. The sum of all components equals the overall predicted value for an observation.
predict_components(object, new_data, max_interaction = NULL, predictors = NULL)
object |
A fit object of class |
new_data |
Data for new observations to predict. |
max_interaction |
|
predictors |
|
Extracts all possible components up to degree max_interaction, which can be at most the value set when calling rpf(). The intercept is always included.
Optionally, predictors can be specified to only include components that involve the given variables.
If max_interaction is greater than length(predictors), max_interaction will be lowered accordingly.
A list with elements:
m (data.table): Components for each main effect and interaction term, representing the functional decomposition of the prediction. All components together with the intercept sum up to the prediction. For multiclass classification, the number of output columns is multiplied by the number of levels in the outcome.
intercept (numeric(1)): Expected value of the prediction.
x (data.table): Copy of new_data containing predictors selected by predictors.
target_levels (character): For multiclass classification only: Vector of target levels which can be used to disassemble m, as names include both term and target level.
Depending on the number of predictors and max_interaction, the number of components grows rapidly, up to sum(choose(ncol(new_data), seq_len(max_interaction))).
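To get a feel for that growth, the count can be computed in base R before calling predict_components(). This is only a minimal sketch of the formula above: n_components() is a hypothetical helper, not part of the package, and p stands for the number of predictor columns.

```r
# Hypothetical helper (not part of randomPlantedForest): number of
# main-effect and interaction components for p predictors when
# interactions up to degree max_interaction are extracted.
n_components <- function(p, max_interaction) {
  # An interaction cannot involve more variables than there are predictors
  k <- min(max_interaction, p)
  sum(choose(p, seq_len(k)))
}

n_components(3, 1)   # main effects only
n_components(3, 3)   # every non-empty subset of 3 predictors
n_components(10, 3)
n_components(30, 3)
```

With 3 predictors and max_interaction = 3 this gives 7 components, while 30 predictors at the same degree already give 4525, so lowering max_interaction or restricting predictors keeps the result manageable.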
# Regression task, only some predictors
train <- mtcars[1:20, 1:4]
test <- mtcars[21:32, 1:4]

set.seed(23)
rpfit <- rpf(mpg ~ ., data = train, max_interaction = 3, ntrees = 30)

# Extract all components, including main effects and interaction terms up to `max_interaction`
(components <- predict_components(rpfit, test))

# sums to prediction
cbind(
  m_sum = rowSums(components$m) + components$intercept,
  prediction = predict(rpfit, test)
)

# Only get components with interactions of a lower degree, ignoring 3-way interactions
predict_components(rpfit, test, max_interaction = 2)

# Only retrieve main effects
(main_effects <- predict_components(rpfit, test, max_interaction = 1))

# The difference is the combined contribution of interaction effects
cbind(
  m_sum = rowSums(main_effects$m) + main_effects$intercept,
  prediction = predict(rpfit, test)
)
Random Planted Forest Predictions
## S3 method for class 'rpf'
predict(
  object,
  new_data,
  type = ifelse(object$mode == "regression", "numeric", "prob"),
  ...
)
object |
A fit object of class |
new_data |
Data for new observations to predict. |
type |
For classification and If |
... |
Unused. |
For regression: A tbl with column .pred with the same number of rows as new_data.
For classification: A tbl with one column for each level in y containing class probabilities if type = "prob".
For type = "class", one column .pred with class predictions is returned.
For type = "numeric" or "link", one column .pred with raw predictions.
# Regression with L2 loss
rpfit <- rpf(y = mtcars$mpg, x = mtcars[, c("cyl", "wt")])
predict(rpfit, mtcars[, c("cyl", "wt")])
Print an rpf fit
## S3 method for class 'rpf'
print(x, ...)
x |
An object of class |
... |
Further arguments passed to or from other methods. |
Invisibly: x.
rpf.
rpf(mpg ~ cyl + wt + drat, data = mtcars, max_interaction = 2, ntrees = 10)
These methods are provided to avoid flooding the console with long nested lists containing tree structures. Note
## S3 method for class 'rpf_forest'
print(x, ...)

## S3 method for class 'rpf_forest'
str(object, ...)
x |
Object of class |
... |
Further arguments passed to or from other methods. |
object |
Object of class |
rpfit <- rpf(mpg ~ cyl + wt, data = mtcars, ntrees = 10)
print(rpfit$forest)
str(rpfit$forest)
TODO: Explain what this does
purify(x, ...)

## Default S3 method:
purify(x, ...)

## S3 method for class 'rpf'
purify(x, ...)

is_purified(x)
x |
An object of class |
... |
(Unused) |
Unless rpf() is called with purify = TRUE, the forest has to be purified after fitting to ensure the components extracted by predict_components() are valid.
predict_components() will automatically purify a forest if is_purified() reports FALSE.
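The workflow described above can be sketched as follows. It uses only rpf(), purify() and is_purified() as documented on this page and assumes the randomPlantedForest package is installed. That is_purified() reports FALSE right after fitting is inferred from the default purify = FALSE rather than shown in the original, so no output is claimed here.

```r
library(randomPlantedForest)

set.seed(23)

# Fit without purification (purify = FALSE is the documented default)
rpfit <- rpf(mpg ~ cyl + wt, data = mtcars, max_interaction = 2, ntrees = 10)
is_purified(rpfit)

# Purify explicitly; purify() returns the rpf object invisibly,
# so reassigning it is safe
rpfit <- purify(rpfit)
is_purified(rpfit)

# Alternatively, request purification at fit time
rpfit_pure <- rpf(
  mpg ~ cyl + wt,
  data = mtcars, max_interaction = 2, ntrees = 10, purify = TRUE
)
is_purified(rpfit_pure)
```

Purifying at fit time is convenient when components will be extracted anyway; deferring it keeps fitting cheaper when only predict() is needed.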
Invisibly: The rpf object.
rpfit <- rpf(mpg ~ ., data = mtcars, max_interaction = 2, ntrees = 10)
purify(rpfit)
Random Planted Forest
rpf(x, ...)

## S3 method for class 'data.frame'
rpf(
  x,
  y,
  max_interaction = 1,
  ntrees = 50,
  splits = 30,
  split_try = 10,
  t_try = 0.4,
  deterministic = FALSE,
  nthreads = 1,
  purify = FALSE,
  cv = FALSE,
  loss = "L2",
  delta = 0,
  epsilon = 0.1,
  ...
)

## S3 method for class 'matrix'
rpf(
  x,
  y,
  max_interaction = 1,
  ntrees = 50,
  splits = 30,
  split_try = 10,
  t_try = 0.4,
  deterministic = FALSE,
  nthreads = 1,
  purify = FALSE,
  cv = FALSE,
  loss = "L2",
  delta = 0,
  epsilon = 0.1,
  ...
)

## S3 method for class 'formula'
rpf(
  formula,
  data,
  max_interaction = 1,
  ntrees = 50,
  splits = 30,
  split_try = 10,
  t_try = 0.4,
  deterministic = FALSE,
  nthreads = 1,
  purify = FALSE,
  cv = FALSE,
  loss = "L2",
  delta = 0,
  epsilon = 0.1,
  ...
)

## S3 method for class 'recipe'
rpf(
  x,
  data,
  max_interaction = 1,
  ntrees = 50,
  splits = 30,
  split_try = 10,
  t_try = 0.4,
  deterministic = FALSE,
  nthreads = 1,
  purify = FALSE,
  cv = FALSE,
  loss = "L2",
  delta = 0,
  epsilon = 0.1,
  ...
)
x , data
|
Feature |
... |
(Unused). |
y |
Target vector for use with |
max_interaction |
|
ntrees |
|
splits |
|
split_try |
|
t_try |
|
deterministic |
|
nthreads |
|
purify |
|
cv |
|
loss |
|
delta |
|
epsilon |
|
formula |
Formula specification, e.g. y ~ x1 + x2. |
Object of class "rpf" with model object contained in $fit.
# Regression with x and y
rpfit <- rpf(x = mtcars[, c("cyl", "wt")], y = mtcars$mpg)

# Regression with formula
rpfit <- rpf(mpg ~ cyl + wt, data = mtcars)