| Function | Works |
|---|---|
| tidypredict_fit(), tidypredict_sql(), parse_model() | ✔ |
| tidypredict_to_column() | ✗ |
| tidypredict_test() | ✔ |
| tidypredict_interval(), tidypredict_sql_interval() | ✗ |
| parsnip | ✔ |

Here is a simple ranger() model using the
mtcars dataset:
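
The fitting call itself is not shown here. A minimal sketch, assuming the same settings as the parsnip example further down (five trees, a maximum depth of two, mpg regressed on all other columns). The seed is an arbitrary choice for illustration, so the fitted trees will not match the output printed on this page:

```r
library(dplyr)
library(tidypredict)
library(ranger)

# Assumed settings: 5 trees of depth 2, mirroring the parsnip example on this page
set.seed(100)  # arbitrary seed, used only to make this sketch reproducible
model <- ranger(mpg ~ ., data = mtcars, num.trees = 5, max.depth = 2)
```
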
The parser is based on the output from the
ranger::treeInfo() function. parse_model() returns one decision
path for each row with a non-NA value in the prediction
column, that is, one per terminal node.
treeInfo(model) %>%
head()
#> nodeID leftChild rightChild splitvarID splitvarName splitval terminal
#> 1 0 1 2 8 gear 3.50 FALSE
#> 2 1 3 4 2 hp 192.50 FALSE
#> 3 2 5 6 4 wt 2.26 FALSE
#> 4 3 NA NA NA <NA> NA TRUE
#> 5 4 NA NA NA <NA> NA TRUE
#> 6 5 NA NA NA <NA> NA TRUE
#> prediction
#> 1 NA
#> 2 NA
#> 3 NA
#> 4 16.02000
#> 5 12.18333
#> 6 29.98750

The output from parse_model() is transformed into a
dplyr, a.k.a. Tidy Eval, formula. Each decision tree becomes
one dplyr::case_when() statement; the statements are then summed
and divided by the number of trees to give the forest's average prediction.
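
To make the shape of these statements concrete, here is a small sketch that needs only dplyr and rlang. The two tree expressions are copied from the tidypredict_fit(model) output on this page. Averaging just two of them is an illustration only; the real formula averages all five trees. The names tree_1, tree_2 and two_tree_fit are hypothetical.

```r
library(dplyr)
library(rlang)

# Two of the five per-tree expressions, copied from the printed formula
tree_1 <- expr(case_when(
  gear <= 3.5 ~ case_when(hp <= 192.5 ~ 16.02, .default = 12.1833333333333),
  .default = case_when(wt <= 2.26 ~ 29.9875, .default = 20.0076923076923)
))
tree_2 <- expr(case_when(
  wt <= 3.295 ~ case_when(vs <= 0.5 ~ 21.1833333333333, .default = 25.8714285714286),
  .default = case_when(qsec <= 18.15 ~ 14.1588235294118, .default = 18.5)
))

# Combine the way the full formula does: sum the trees, divide by their count
two_tree_fit <- expr(((!!tree_1) + (!!tree_2)) / 2)

# Evaluate each expression against the data by splicing it into mutate()
preds <- mtcars %>%
  mutate(
    pred_1 = !!tree_1,
    pred_2 = !!tree_2,
    pred_avg = !!two_tree_fit
  ) %>%
  select(gear, hp, wt, vs, qsec, pred_1, pred_2, pred_avg)

head(preds)
```

Because each tree is an ordinary R expression, nothing from ranger is needed at scoring time; the same nesting is what later gets translated to SQL.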
tidypredict_fit(model)
#> (case_when(gear <= 3.5 ~ case_when(hp <= 192.5 ~ 16.02, .default = 12.1833333333333),
#> .default = case_when(wt <= 2.26 ~ 29.9875, .default = 20.0076923076923)) +
#> case_when(wt <= 3.295 ~ case_when(vs <= 0.5 ~ 21.1833333333333,
#> .default = 25.8714285714286), .default = case_when(qsec <=
#> 18.15 ~ 14.1588235294118, .default = 18.5)) + case_when(disp <=
#> 163.8 ~ case_when(hp <= 79.5 ~ 28.125, .default = 21.225),
#> .default = case_when(wt <= 4.5475 ~ 17.15, .default = 10.4)) +
#> case_when(cyl <= 5 ~ case_when(disp <= 101.55 ~ 31.65, .default = 23.3),
#> .default = case_when(cyl <= 7 ~ 20.2666666666667, .default = 15.3538461538462)) +
#> case_when(cyl <= 5 ~ case_when(wt <= 1.885 ~ 31.5666666666667,
#> .default = 23.9714285714286), .default = case_when(cyl <=
#> 7 ~ 19.8, .default = 14.91875)))/5

From there, the Tidy Eval formula can be used anywhere it can be
evaluated. tidypredict provides three paths:
- Use directly inside dplyr, mutate(mtcars, !! tidypredict_fit(model))
- Add tidypredict_to_column(model) to a piped command set (the table above marks this function as not working for ranger models)
- Use tidypredict_sql(model, con), where con is a database connection, to retrieve the SQL statement

tidypredict also supports ranger model
objects fitted via the parsnip package.
library(parsnip)
parsnip_model <- rand_forest(mode = "regression", trees = 5) %>%
set_engine("ranger", max.depth = 2) %>%
fit(mpg ~ ., data = mtcars)
tidypredict_fit(parsnip_model)
#> (case_when(gear <= 3.5 ~ case_when(disp <= 197.95 ~ 21.5, .default = 15.42),
#> .default = case_when(drat <= 4 ~ 23.4444444444444, .default = 27.9833333333333)) +
#> case_when(hp <= 131.5 ~ case_when(wt <= 2.2775 ~ 30.5, .default = 21.1125),
#> .default = case_when(drat <= 3.035 ~ 10.4, .default = 16.8833333333333)) +
#> case_when(cyl <= 5 ~ case_when(hp <= 78.5 ~ 31.15, .default = 26.1285714285714),
#> .default = case_when(disp <= 266.9 ~ 20.2, .default = 15.2583333333333)) +
#> case_when(vs <= 0.5 ~ case_when(drat <= 4.325 ~ 16.7, .default = 26),
#> .default = case_when(wt <= 2.26 ~ 32.2333333333333, .default = 20.6375)) +
#> case_when(disp <= 120.65 ~ case_when(vs <= 0.5 ~ 26, .default = 31.2777777777778),
#> .default = case_when(wt <= 3.3125 ~ 21.4555555555556,
#> .default = 16.4307692307692)))/5
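
As a closing sketch of the dplyr and SQL paths listed earlier, the block below scores mtcars locally and then asks for the SQL translation. It assumes a plain ranger fit with the same settings as the parsnip example, an arbitrary seed, and a simulated connection from dbplyr::simulate_dbi() in place of a real database. The names scored and sql_text are hypothetical, and the exact SQL produced depends on the installed tidypredict and dbplyr versions.

```r
library(dplyr)
library(ranger)
library(tidypredict)

# Assumed settings, mirroring the parsnip example above
set.seed(100)  # arbitrary seed for this sketch
model <- ranger(mpg ~ ., data = mtcars, num.trees = 5, max.depth = 2)

# dplyr path: splice the formula into mutate() to score a local data frame
scored <- mtcars %>%
  mutate(fit = !!tidypredict_fit(model))
head(select(scored, mpg, fit))

# SQL path: translate the same formula against a simulated connection,
# so no live database is required to inspect the statement
sql_text <- tidypredict_sql(model, dbplyr::simulate_dbi())
sql_text
```

With a real backend, the simulated connection would be replaced by the DBI connection in use, so that the SQL dialect matches the target database.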