| Type: | Package |
| Title: | Fitting Interpretable Neural Additive Models Using Orthogonalization |
| Version: | 1.0.0 |
| Description: | An algorithm for fitting interpretable additive neural networks for identifiable and visualizable feature effects using post hoc orthogonalization. Fit custom neural networks intuitively using established 'R' 'formula' notation, including interaction effects of arbitrary order while preserving identifiability to enable a functional decomposition of the prediction function. For more details see Koehler et al. (2025) <doi:10.1038/s44387-025-00033-7>. |
| License: | MIT + file LICENSE |
| BugReports: | https://github.com/Koehlibert/ONAM_R/issues |
| Depends: | keras3, reticulate |
| Imports: | dplyr, scales, rlang, ggplot2, pROC |
| Suggests: | akima, RColorBrewer, testthat (≥ 3.0.0) |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.3 |
| Config/testthat/edition: | 3 |
| NeedsCompilation: | no |
| Packaged: | 2025-11-06 09:07:04 UTC; koehler |
| Author: | David Köhler |
| Maintainer: | David Köhler <koehler@imbie.uni-bonn.de> |
| Repository: | CRAN |
| Date/Publication: | 2025-11-11 09:50:18 UTC |
Get variance decomposition of orthogonal neural additive model
Description
Decomposes the variance explained by a fitted orthogonal neural additive model into the contributions of each interaction order.
Usage
decompose(object, data = NULL)
Arguments
object |
Either model of class |
data |
Data for which the model is to be evaluated. If |
Value
Returns a named vector giving the percentage of variance explained by each interaction order.
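The base-R sketch below illustrates one way a percentage-of-variance decomposition by interaction order can be computed from additive effect outputs. It is not the package's implementation: the helper `variance_by_order()`, the effect matrix `effects`, its column names, and the order labels are hypothetical, and the shares are assumed to be taken relative to the variance of the summed prediction.

```r
# Hypothetical helper (not part of ONAM): percentage of the prediction's
# variance explained by the summed effects of each interaction order.
variance_by_order <- function(effects, order) {
  # effects: numeric matrix, one column per additive term
  # order:   interaction order of each column of `effects`
  total_var <- var(rowSums(effects))
  orders <- sort(unique(order))
  shares <- vapply(orders, function(k) {
    var(rowSums(effects[, order == k, drop = FALSE])) / total_var * 100
  }, numeric(1))
  names(shares) <- paste0("order_", orders)
  shares
}

# Toy effects evaluated on simulated data
set.seed(1)
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
f1 <- sin(x1)
f2 <- x2^2
# Make the interaction mean-zero and uncorrelated with the main effects by
# taking least-squares residuals on an intercept and both main effects.
f12 <- qr.resid(qr(cbind(1, f1, f2)), x1 * x2)
effects <- cbind(x1 = f1, x2 = f2, "x1:x2" = f12)
variance_by_order(effects, order = c(1, 1, 2))
```

Because the interaction column here is orthogonal to the main effects, the two shares add up to 100; for correlated components they generally would not, which is one motivation for orthogonalizing the effects before decomposing.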
Examples
# Basic example for a simple ONAM-model
# Create training data
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
y <- sin(x1) + ifelse(x2 > 0, pweibull(x2, shape = 3),
                      pweibull(-x2, shape = 0.5)) +
  x1 * x2
data_train <- cbind(x1, x2, y)
# Define model
model_formula <- y ~ mod1(x1) + mod1(x2) +
  mod1(x1, x2)
mod1 <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 16, activation = "relu") %>%
    layer_dense(units = 8, activation = "linear",
                use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear",
                use_bias = TRUE)
  keras_model(inputs, outputs)
}
list_of_deep_models <- list(mod1 = mod1)
# Fit model
mod <- onam(model_formula, list_of_deep_models,
            data_train, n_ensemble = 1, epochs = 10)
decompose(mod)
Set up conda environment for keras functionality
Description
Helper function to install Keras and the packages required for package
functionality into a conda environment. Use this function if
keras3::install_keras() does not work, especially on Windows machines.
Usage
install_conda_env(envname = "r-keras", python_version = "python=3.10")
Arguments
envname |
Name for the conda environment to be created. |
python_version |
Python version to be installed in the conda environment. |
Value
No return value, called for side effects
See Also
Fit orthogonal neural additive model
Description
Fits an interpretable neural additive model with post hoc orthogonalization for a given network architecture and user-specified feature sets.
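As a rough illustration of what post hoc orthogonalization means, the base-R sketch below removes from an interaction effect the part that an intercept and the main effects of its features explain in the least-squares sense, and shifts that part into the lower-order terms so the overall prediction is unchanged. This is a simplified sketch of the general idea, assuming the effects are available as numeric vectors evaluated on the training data; the helper `orthogonalize_interaction()` is hypothetical, and the package's actual procedure (see Koehler et al. (2025)) operates on the fitted networks and may differ.

```r
# Hypothetical helper (not part of ONAM): orthogonalize an interaction effect
# f12 against an intercept and the main effects f1, f2 by least squares.
orthogonalize_interaction <- function(f1, f2, f12) {
  fit <- qr(cbind(1, f1, f2))
  # Coefficients of the projection of f12 onto span{1, f1, f2}
  beta <- unname(qr.coef(fit, f12))
  # Residual: the part of f12 not representable by lower-order terms
  f12_orth <- qr.resid(fit, f12)
  # Move the projected part into the intercept and main effects so that
  # f1 + f2 + f12 equals intercept + f1 + f2 + f12 (new) for every row.
  list(
    intercept = beta[1],
    f1 = f1 * (1 + beta[2]),
    f2 = f2 * (1 + beta[3]),
    f12 = f12_orth
  )
}

# Usage on toy effects
set.seed(42)
n <- 500
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
f1 <- sin(x1)
f2 <- x2
# An interaction effect that has absorbed some main-effect signal
f12 <- x1 * x2 + 0.5 * x2 + 0.3
res <- orthogonalize_interaction(f1, f2, f12)
# The prediction function is unchanged ...
all.equal(f1 + f2 + f12, res$intercept + res$f1 + res$f2 + res$f12)
# ... and the interaction is now orthogonal to the lower-order terms
crossprod(cbind(1, res$f1, res$f2), res$f12)
```

Keeping the sum of all terms fixed is what makes the adjustment "post hoc": the fitted prediction function is left intact, and only its split into main and interaction effects changes, which is what makes the individual effects identifiable.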
Usage
onam(
formula,
list_of_deep_models,
data,
model = NULL,
prediction_function = NULL,
model_data = NULL,
categorical_features = NULL,
target = "continuous",
n_ensemble = 10,
epochs = 500,
callback = NULL,
progresstext = FALSE,
verbose = 0
)
Arguments
formula |
Formula for model fitting. Specify deep parts with the same
name as |
list_of_deep_models |
List of named models used in |
data |
Data to which the model is fitted.
model |
Prediction model that is to be explained. Output of the model as
returned from |
prediction_function |
Prediction function to be used to generate the
outcome. Only used if |
model_data |
Data used for generating predictions of |
categorical_features |
Vector of feature names of categorical features. |
target |
Target of the prediction task. Can be either "continuous" or "binary". For "continuous" (default), an additive model for the prediction of a continuous outcome is fitted. For "binary", a binary classification model with sigmoid activation in the last layer is fitted. |
n_ensemble |
Number of ensemble members, i.e., orthogonal neural additive models fitted and combined into the ensemble. |
epochs |
Number of epochs to train the model. See
|
callback |
Callback to be called during training. See
|
progresstext |
Show model fitting progress. If |
verbose |
Verbose argument for internal model fitting; used for
debugging. See |
Value
Returns a model object of class onam, containing all ensemble
members, ensemble weights, and main and interaction effect outputs.
Examples
# Basic example for a simple ONAM-model
# Create training data
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
y <- sin(x1) + ifelse(x2 > 0, pweibull(x2, shape = 3),
                      pweibull(-x2, shape = 0.5)) +
  x1 * x2
data_train <- cbind(x1, x2, y)
# Define model
model_formula <- y ~ mod1(x1) + mod1(x2) +
  mod1(x1, x2)
mod1 <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 16, activation = "relu") %>%
    layer_dense(units = 8, activation = "linear",
                use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear",
                use_bias = TRUE)
  keras_model(inputs, outputs)
}
list_of_deep_models <- list(mod1 = mod1)
# Fit model
mod <- onam(model_formula, list_of_deep_models,
            data_train, n_ensemble = 1, epochs = 10)
summary(mod)
Plot Interaction Effect
Description
Plots the estimated interaction effect of two features from a fitted orthogonal neural additive model.
Usage
plot_inter_effect(
object,
feature1,
feature2,
interpolate = FALSE,
custom_colors = "spectral",
n_interpolate = 200
)
Arguments
object |
Either model of class |
feature1, feature2 |
Names of the two features whose interaction effect is to be plotted. |
interpolate |
If TRUE, values will be interpolated for a smooth plot. If FALSE (default), only observations in the data will be plotted. |
custom_colors |
Color palette object for the interaction plot. Default is "spectral", which returns a color palette based on the spectral theme. |
n_interpolate |
Number of values per coordinate axis to interpolate. Ignored if 'interpolate = FALSE'. |
Value
Returns a 'ggplot2' object of the specified interaction effect.
Examples
# Basic example for a simple ONAM-model
# Create training data
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
y <- sin(x1) + ifelse(x2 > 0, pweibull(x2, shape = 3),
                      pweibull(-x2, shape = 0.5)) +
  x1 * x2
data_train <- cbind(x1, x2, y)
# Define model
model_formula <- y ~ mod1(x1) + mod1(x2) +
  mod1(x1, x2)
mod1 <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 16, activation = "relu") %>%
    layer_dense(units = 8, activation = "linear",
                use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear",
                use_bias = TRUE)
  keras_model(inputs, outputs)
}
list_of_deep_models <- list(mod1 = mod1)
# Fit model
mod <- onam(model_formula, list_of_deep_models,
            data_train, n_ensemble = 1, epochs = 10)
plot_inter_effect(mod, "x1", "x2")
Plot Main Effect
Description
Plots the estimated main effect of a single feature from a fitted orthogonal neural additive model.
Usage
plot_main_effect(object, feature)
Arguments
object |
Either model of class |
feature |
Feature for which the effect is to be plotted; must be present in the model formula. For interaction terms, use plot_inter_effect(). |
Value
Returns a 'ggplot2' object of the specified main effect.
Examples
# Basic example for a simple ONAM-model
# Create training data
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
y <- sin(x1) + ifelse(x2 > 0, pweibull(x2, shape = 3),
                      pweibull(-x2, shape = 0.5)) +
  x1 * x2
data_train <- cbind(x1, x2, y)
# Define model
model_formula <- y ~ mod1(x1) + mod1(x2) +
  mod1(x1, x2)
mod1 <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 16, activation = "relu") %>%
    layer_dense(units = 8, activation = "linear",
                use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear",
                use_bias = TRUE)
  keras_model(inputs, outputs)
}
list_of_deep_models <- list(mod1 = mod1)
# Fit model
mod <- onam(model_formula, list_of_deep_models,
            data_train, n_ensemble = 1, epochs = 10)
plot_main_effect(mod, "x1")
Evaluate orthogonal neural additive model
Description
Evaluates a fitted orthogonal neural additive model, returning the model output together with the main and interaction effects.
Usage
## S3 method for class 'onam'
predict(object, ..., data = NULL)
Arguments
object |
model of class |
... |
Some methods for this generic require additional arguments. None are used in this method. |
data |
Data for which the model is to be evaluated. If NULL (default),
data with which |
Value
Returns a list containing the data, the model output for each observation in
the data, and the main and interaction effects obtained by the model.
Get summary of an onam object
Description
Generates a summary of a fitted onam object, including
information on the ensembling strategy and performance metrics such as
correlation and degree of interpretability.
Usage
## S3 method for class 'onam'
summary(object, ...)
## S3 method for class 'summary.onam'
print(x, ...)
Arguments
object |
onam object of class |
... |
further arguments passed to or from other methods. |
x |
object of class |
Details
For examples, see example(onam).
Value
Returns a summary of the onam object, including the model inputs, the number
of ensemble members, the correlation of the model output with the original outcome
variable, and the interpretability metrics i_1 and i_2.