SMMAL_vignette

Introduction

This vignette demonstrates how to use the SMMAL package to estimate the Average Treatment Effect (ATE) using semi-supervised machine learning. We provide an example dataset and walk through the required input format and function usage.

Import Sample Data

The sample data contain 1000 observations, with 60% of Y and A missing at random. Y is the outcome, A is the treatment indicator, X contains the covariates, and S contains the surrogates.

Missing values are encoded as NA. The package can handle datasets with a high proportion of missing values, but the sample must be large enough that each cross-validation fold contains at least 20 labeled observations.
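The fold-size requirement can be checked before fitting. The following is a minimal sketch on simulated data: the names `Y_toy` and `fold_toy` are hypothetical, and the random fold assignment is only an assumption about how folds might be formed, since SMMAL builds its own folds internally.

```r
# Toy data, simulated for illustration only (not the packaged sample data).
set.seed(123)
n <- 1000
Y_toy <- rbinom(n, 1, 0.5)
Y_toy[sample(n, 600)] <- NA          # 60% of outcomes missing

nfold <- 5
# Hypothetical random fold assignment; SMMAL creates its own folds internally.
fold_toy <- sample(rep(seq_len(nfold), length.out = n))

# Number of labeled (non-missing) outcomes in each fold.
labeled_per_fold <- tapply(!is.na(Y_toy), fold_toy, sum)
print(labeled_per_fold)

# TRUE when every fold meets the 20-labeled-observation threshold.
enough_labels <- all(labeled_per_fold >= 20)
print(enough_labels)
```

If the check fails for your own data, reducing nfold or collecting more labeled observations are the two obvious remedies.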

library(SMMAL)

file_path <- system.file("extdata", "sample_data_withmissing.rds", package = "SMMAL")
dat <- readRDS(file_path)


file_path2 <- system.file("extdata", "semi_supervised_data.rds", package = "SMMAL")
data_loaded <- readRDS(file_path2)

Prepare Inputs

The inputs S and X must be data frames, even if each contains only a single variable.

  # Y and A are numeric vectors
  Y <- dat$Y
  A <- dat$A
  
  # S and X must be data frames
  S <- data.frame(dat$S)
  X <- data.frame(dat$X)
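When your own data start out as plain vectors or matrices, the conversion and a few shape checks might look like the sketch below. All object names (`s_vec`, `x_mat`, `S_toy`, `X_toy`, `Y_toy`, `A_toy`) are hypothetical and the values are simulated; the checks reflect only the requirements stated in this vignette.

```r
# Toy inputs, simulated for illustration only.
set.seed(1)
n <- 10
s_vec <- rnorm(n)                    # a single surrogate stored as a vector
x_mat <- matrix(rnorm(n * 2), n, 2)  # two covariates stored as a matrix

S_toy <- data.frame(S = s_vec)       # one-column data frame
X_toy <- as.data.frame(x_mat)        # columns are named V1 and V2

# Outcome and treatment as numeric vectors, with NA for unlabeled rows.
Y_toy <- c(1, 0, NA, 1, NA, 0, NA, NA, 1, 0)
A_toy <- c(1, 1, NA, 0, NA, 0, NA, NA, 1, 0)

# Basic shape checks before calling SMMAL().
stopifnot(
  is.data.frame(S_toy), is.data.frame(X_toy),
  is.numeric(Y_toy), is.numeric(A_toy),
  nrow(S_toy) == length(Y_toy),
  nrow(X_toy) == length(Y_toy),
  length(A_toy) == length(Y_toy)
)
```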

Estimate ATE with SMMAL & Output

Users can choose the model used for the nuisance functions by setting the cf_model parameter. If cf_model is not specified, it defaults to “bspline”.

After cross-validation and prediction, the best-performing model is the one with the lowest cross-entropy (log loss). Users can set the number of cross-validation folds with the nfold parameter. If nfold is not specified, it defaults to 5.
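To make the selection rule concrete, here is a small sketch of choosing between two candidates by log loss. The function `log_loss_sketch` and the probabilities are hypothetical; this is a generic definition of cross-entropy and not necessarily the one the package uses internally.

```r
# Illustrative log-loss function; eps guards against log(0).
# Generic definition, not necessarily the package's internal one.
log_loss_sketch <- function(y_true, y_pred, eps = 1e-15) {
  p <- pmin(pmax(y_pred, eps), 1 - eps)
  -mean(y_true * log(p) + (1 - y_true) * log(1 - p))
}

y <- c(1, 0, 1, 1, 0)
# Hypothetical predicted probabilities from two candidate models.
candidates <- list(
  model_a = c(0.9, 0.2, 0.8, 0.7, 0.1),
  model_b = c(0.6, 0.5, 0.5, 0.6, 0.4)
)

# One log-loss value per candidate; lower is better.
losses <- sapply(candidates, function(p) log_loss_sketch(y, p))
print(losses)

best_model <- names(which.min(losses))
print(best_model)
```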

SMMAL_output1 <- SMMAL(Y = Y, A = A, S = S, X = X)
print(SMMAL_output1)
#> $est
#> [1] 0.1048938
#> 
#> $se
#> [1] 0.04907399

Other options for cf_model include “xgboost”:

SMMAL_output2 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "xgboost")
print(SMMAL_output2)
#> $est
#> [1] 0.09174881
#> 
#> $se
#> [1] 0.05081966

or “randomforest”:

SMMAL_output3 <- SMMAL(Y=Y,A=A,S=S,X=X,cf_model= "randomforest")
print(SMMAL_output3)
#> $est
#> [1] 0.09918991
#> 
#> $se
#> [1] 0.05248152

or “glm”:

SMMAL_output4 <- SMMAL(Y = Y, A = A, S = S, X = X, cf_model = "glm")
print(SMMAL_output4)
#> $est
#> [1] 0.1341408
#> 
#> $se
#> [1] 0.05347756
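The four fits can be compared side by side. The sketch below copies the printed est and se values into a data frame and adds Wald-type 95% intervals. The interval formula assumes the estimator is approximately normal, which this vignette does not state, so treat it as an illustration and not as the package's recommended inference procedure.

```r
# Estimates and standard errors copied from the printed outputs above.
results <- data.frame(
  cf_model = c("bspline", "xgboost", "randomforest", "glm"),
  est = c(0.1048938, 0.09174881, 0.09918991, 0.1341408),
  se  = c(0.04907399, 0.05081966, 0.05248152, 0.05347756)
)

# Wald-type 95% interval; assumes approximate normality of the estimator.
z <- qnorm(0.975)
results$lower <- results$est - z * results$se
results$upper <- results$est + z * results$se

print(results, digits = 3)
```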

Using Your Own custom_model_fun

Users may customize the feature‐selection or penalization strategy by supplying their own function through the custom_model_fun argument. To do so, pass a function that meets these requirements:

  1. Function Signature: It must accept exactly these arguments (in this order): X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss

(X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, and nfold are supplied internally by SMMAL to partition and fit the data.)

(log_loss is a function for computing cross‐entropy (log‐loss). Your function can call log_loss(true_labels, predicted_probs) to evaluate each tuning parameter.)

  2. Return Value: It must return a list with one element per candidate tuning-parameter setting. In the packaged SMMAL_ada_lasso(), that is one element for every combination of the ridge and lambda values defined in param_fun(). Each element should hold the out‐of‐fold predicted probabilities for all observations, stacked across every held‐out fold. In SMMAL_ada_lasso(), each element is a one-column matrix with one row per entry of foldid, so every row, labeled or unlabeled, receives a prediction.
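As a rough illustration of these requirements, the sketch below defines a hypothetical learner, `my_glm_fun`, that fits an unpenalized logistic regression in each fold and returns a single candidate. The argument list follows the printed source of the packaged SMMAL_ada_lasso(), which takes X_full and foldid in addition to the other arguments. The inputs are simulated and only mimic the argument names; in particular, treating sub_set as a full-sample logical vector is an assumption, and whether SMMAL accepts a one-element list has not been verified here.

```r
# Hypothetical custom learner: unpenalized logistic regression, one candidate.
my_glm_fun <- function(X, Y, X_full, foldid, foldid_labelled, sub_set,
                       labeled_indices, nfold, log_loss) {
  preds <- matrix(NA_real_, nrow = length(foldid), ncol = 1)
  for (ifold in seq_len(nfold)) {
    # Train on labeled rows outside the held-out fold.
    trainpos <- which((foldid_labelled != ifold) & sub_set[labeled_indices])
    # Predict for every row (labeled or not) in the held-out fold.
    testpos <- which(foldid == ifold)
    train_df <- data.frame(y = Y[trainpos], X[trainpos, , drop = FALSE])
    test_df <- data.frame(X_full[testpos, , drop = FALSE])
    fit <- glm(y ~ ., data = train_df, family = binomial())
    preds[testpos, 1] <- predict(fit, newdata = test_df, type = "response")
  }
  list(preds)  # a list with a single candidate
}

# Simulated inputs that only mimic the argument names; not package data.
set.seed(42)
n_all <- 200
p <- 3
nfold <- 5
X_full <- matrix(rnorm(n_all * p), n_all, p,
                 dimnames = list(NULL, paste0("V", seq_len(p))))
labeled_indices <- sort(sample(n_all, 80))
X <- X_full[labeled_indices, , drop = FALSE]
Y <- rbinom(length(labeled_indices), 1, plogis(X[, 1]))
foldid <- sample(rep(seq_len(nfold), length.out = n_all))
foldid_labelled <- foldid[labeled_indices]
sub_set <- rep(TRUE, n_all)  # assumption: indexed on the full-sample scale
log_loss <- function(y_true, y_pred) {
  -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
}

out <- my_glm_fun(X, Y, X_full, foldid, foldid_labelled, sub_set,
                  labeled_indices, nfold, log_loss)
str(out)
```

This learner does not call log_loss because it has no tuning grid; a learner with several candidates would return one prediction matrix per candidate.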

Below is an example showing how to plug in the packaged SMMAL_ada_lasso() as custom_model_fun. In practice, you could substitute any function with the same signature and return type:

SMMAL_output5 <- SMMAL(Y = Y, A = A, S = S, X = X, custom_model_fun = SMMAL_ada_lasso)
print(SMMAL_output5)
#> $est
#> [1] 0.1280551
#> 
#> $se
#> [1] 0.05729187

Understanding SMMAL_ada_lasso

SMMAL_ada_lasso
#> function (X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, 
#>     nfold, log_loss) 
#> {
#>     fold_predictions <- NULL
#>     param_grid <- param_fun()
#>     ridge_list <- param_grid$ridge
#>     lambda_list <- param_grid$lambda
#>     fold_predictions <- vector("list", length(ridge_list) * length(lambda_list))
#>     for (r in seq_along(ridge_list)) {
#>         ridge_val <- ridge_list[[r]]
#>         for (i in seq_along(lambda_list)) {
#>             lambda <- lambda_list[[i]]
#>             ridge_fit_all <- glmnet::glmnet(X, Y, lambda = ridge_val, 
#>                 alpha = 0, family = "binomial")
#>             ridge_coef <- as.numeric(coef(ridge_fit_all))[-1]
#>             penalty_factors <- 1/(abs(ridge_coef) + 1e-04)
#>             all_preds_matrix <- matrix(NA, nrow = length(foldid), 
#>                 ncol = 1)
#>             for (ifold in seq_len(nfold)) {
#>                 trainpos <- which((foldid_labelled != ifold) & 
#>                   sub_set[labeled_indices])
#>                 testpos <- which(foldid == ifold)
#>                 X_train <- as.matrix(X[trainpos, , drop = FALSE])
#>                 Y_train <- as.numeric(Y[trainpos])
#>                 X_test <- as.matrix(X_full[testpos, , drop = FALSE])
#>                 valid_idx <- which(!is.na(Y_train))
#>                 X_train <- X_train[valid_idx, , drop = FALSE]
#>                 Y_train <- Y_train[valid_idx]
#>                 fit <- glmnet::glmnet(X_train, Y_train, lambda = lambda, 
#>                   alpha = 1, family = "binomial", penalty.factor = penalty_factors, 
#>                   maxit = 1e+06)
#>                 preds <- predict(fit, newx = X_test, type = "response")
#>                 all_preds_matrix[testpos, ] <- preds
#>             }
#>             fold_predictions[[(r - 1) * length(lambda_list) + 
#>                 i]] <- all_preds_matrix
#>         }
#>     }
#>     return(fold_predictions)
#> }
#> <bytecode: 0x00000253d775deb8>
#> <environment: namespace:SMMAL>

Input of SMMAL_ada_lasso

str(data_loaded)
#> List of 9
#>  $ X              : num [1:400, 1:50] 0.672 0.056 0.108 0.142 0.371 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:400] "1" "2" "3" "4" ...
#>   .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#>  $ Y              : int [1:400] 0 0 0 0 1 0 0 1 1 0 ...
#>  $ X_full         : num [1:1000, 1:50] 0.672 0.056 0.108 0.142 0.732 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#>  $ foldid         : num [1:1000] 1 4 5 3 2 3 1 4 4 5 ...
#>  $ foldid_labelled: num [1:400] 1 4 5 3 1 4 5 2 2 3 ...
#>  $ sub_set        : logi [1:400] TRUE TRUE TRUE TRUE TRUE TRUE ...
#>  $ labeled_indices: int [1:400] 1 2 3 4 7 9 12 15 17 18 ...
#>  $ nfold          : num 5
#>  $ log_loss       :function (y_true, y_pred)  
#>   ..- attr(*, "srcref")= 'srcref' int [1:8] 1 15 5 3 15 3 1 5
#>   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x00000253d568a9d0>

Input: X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss

X: The matrix of predictors for the labeled observations (400 rows in the sample data).

Y: Binary outcome vector with one entry per row of X. Any remaining NA values are dropped before model fitting.

X_full: The full matrix of predictors for all observations.

foldid: A vector assigning each observation (labeled or unlabeled) to a fold.

foldid_labelled: Integer vector assigning labeled rows to CV folds (1 to nfold); NA for unlabeled.

sub_set: Logical or integer vector indicating rows included in supervised CV.

labeled_indices: Indices of labeled observations (where Y is not missing).

nfold: Number of cross-validation folds (e.g., 5 or 10).

log_loss: Function that computes log-loss: log_loss(true_labels, pred_probs) returns a single numeric.
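Before running a custom learner, it can help to confirm that these inputs line up. The sketch below uses a hypothetical helper, `check_inputs`, on a simulated list named `toy`. The last check assumes that foldid_labelled equals foldid restricted to labeled_indices; the first few values in the str() output above are consistent with that, but the relationship is inferred and not documented.

```r
# Hypothetical helper: checks that the pieces of the input list line up.
check_inputs <- function(inp) {
  n_lab <- nrow(inp$X)
  n_all <- nrow(inp$X_full)
  c(
    y_matches_x        = length(inp$Y) == n_lab,
    labels_match_x     = length(inp$labeled_indices) == n_lab,
    folds_cover_all    = length(inp$foldid) == n_all,
    same_covariates    = ncol(inp$X) == ncol(inp$X_full),
    folds_in_range     = all(inp$foldid %in% seq_len(inp$nfold)),
    # Inferred relationship, not documented by the package.
    labelled_fold_link = all(inp$foldid_labelled ==
                               inp$foldid[inp$labeled_indices])
  )
}

# Simulated stand-in with the same element names; not package data.
set.seed(7)
toy <- list(nfold = 5)
toy$X_full <- matrix(runif(100 * 4), 100, 4)
toy$labeled_indices <- sort(sample(100, 40))
toy$X <- toy$X_full[toy$labeled_indices, , drop = FALSE]
toy$Y <- rbinom(40, 1, 0.5)
toy$foldid <- sample(rep(1:5, length.out = 100))
toy$foldid_labelled <- toy$foldid[toy$labeled_indices]

checks <- check_inputs(toy)
print(checks)
```

The same helper could be applied to data_loaded once the package data are available.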

Running SMMAL_ada_lasso and Its Output

Output: fold_predictions

When SMMAL_ada_lasso() is used as a custom_model_fun, it returns a list with one element for each combination of ridge and lambda values. Each element is a one-column numeric matrix with one row per observation (labeled and unlabeled), containing the cross-validated predicted probabilities for that combination. In the sample run below, the list has 100 elements, each of dimension 1000 x 1.

Below is a sample run of SMMAL_ada_lasso and its output:

SMMAL_fold_predictions <- SMMAL_ada_lasso(
  X = data_loaded$X,
  Y = data_loaded$Y,
  X_full = data_loaded$X_full,
  foldid = data_loaded$foldid,
  foldid_labelled = data_loaded$foldid_labelled,
  sub_set = data_loaded$sub_set,
  labeled_indices = data_loaded$labeled_indices,
  nfold = data_loaded$nfold,
  log_loss = data_loaded$log_loss
)

str(SMMAL_fold_predictions)
#> List of 100
#>  $ : num [1:1000, 1] 0.485 0.486 0.439 0.359 0.442 ...
#>  $ : num [1:1000, 1] 0.482 0.53 0.421 0.369 0.426 ...
#>  $ : num [1:1000, 1] 0.485 0.548 0.41 0.38 0.425 ...
#>  $ : num [1:1000, 1] 0.496 0.566 0.408 0.393 0.434 ...
#>  $ : num [1:1000, 1] 0.51 0.583 0.4 0.402 0.439 ...
#>  $ : num [1:1000, 1] 0.525 0.602 0.395 0.409 0.441 ...
#>  $ : num [1:1000, 1] 0.54 0.61 0.399 0.415 0.447 ...
#>  $ : num [1:1000, 1] 0.559 0.611 0.409 0.423 0.455 ...
#>  $ : num [1:1000, 1] 0.58 0.607 0.42 0.431 0.463 ...
#>  $ : num [1:1000, 1] 0.588 0.603 0.438 0.434 0.474 ...
#>  $ : num [1:1000, 1] 0.598 0.602 0.46 0.443 0.484 ...
#>  $ : num [1:1000, 1] 0.61 0.603 0.485 0.461 0.482 ...
#>  $ : num [1:1000, 1] 0.623 0.603 0.514 0.48 0.483 ...
#>  $ : num [1:1000, 1] 0.637 0.601 0.525 0.498 0.494 ...
#>  $ : num [1:1000, 1] 0.647 0.594 0.538 0.517 0.508 ...
#>  $ : num [1:1000, 1] 0.648 0.584 0.552 0.534 0.519 ...
#>  $ : num [1:1000, 1] 0.642 0.578 0.563 0.543 0.526 ...
#>  $ : num [1:1000, 1] 0.628 0.577 0.567 0.562 0.533 ...
#>  $ : num [1:1000, 1] 0.621 0.575 0.572 0.581 0.54 ...
#>  $ : num [1:1000, 1] 0.613 0.575 0.574 0.581 0.549 ...
#>  $ : num [1:1000, 1] 0.475 0.404 0.469 0.353 0.455 ...
#>  $ : num [1:1000, 1] 0.485 0.443 0.448 0.359 0.451 ...
#>  $ : num [1:1000, 1] 0.49 0.484 0.429 0.369 0.45 ...
#>  $ : num [1:1000, 1] 0.494 0.526 0.412 0.381 0.436 ...
#>  $ : num [1:1000, 1] 0.499 0.557 0.406 0.394 0.431 ...
#>  $ : num [1:1000, 1] 0.514 0.576 0.4 0.404 0.437 ...
#>  $ : num [1:1000, 1] 0.529 0.595 0.391 0.41 0.441 ...
#>  $ : num [1:1000, 1] 0.545 0.612 0.395 0.412 0.447 ...
#>  $ : num [1:1000, 1] 0.563 0.614 0.404 0.42 0.454 ...
#>  $ : num [1:1000, 1] 0.581 0.609 0.415 0.428 0.462 ...
#>  $ : num [1:1000, 1] 0.591 0.605 0.433 0.433 0.474 ...
#>  $ : num [1:1000, 1] 0.601 0.603 0.455 0.442 0.484 ...
#>  $ : num [1:1000, 1] 0.613 0.604 0.476 0.46 0.482 ...
#>  $ : num [1:1000, 1] 0.627 0.604 0.509 0.482 0.48 ...
#>  $ : num [1:1000, 1] 0.642 0.603 0.523 0.5 0.489 ...
#>  $ : num [1:1000, 1] 0.651 0.596 0.536 0.519 0.504 ...
#>  $ : num [1:1000, 1] 0.652 0.586 0.549 0.537 0.518 ...
#>  $ : num [1:1000, 1] 0.65 0.579 0.561 0.541 0.525 ...
#>  $ : num [1:1000, 1] 0.637 0.577 0.566 0.559 0.532 ...
#>  $ : num [1:1000, 1] 0.622 0.575 0.571 0.579 0.539 ...
#>  $ : num [1:1000, 1] 0.48 0.455 0.436 0.384 0.468 ...
#>  $ : num [1:1000, 1] 0.502 0.493 0.409 0.397 0.459 ...
#>  $ : num [1:1000, 1] 0.525 0.535 0.396 0.407 0.452 ...
#>  $ : num [1:1000, 1] 0.539 0.578 0.386 0.413 0.447 ...
#>  $ : num [1:1000, 1] 0.557 0.614 0.385 0.416 0.445 ...
#>  $ : num [1:1000, 1] 0.575 0.62 0.391 0.419 0.45 ...
#>  $ : num [1:1000, 1] 0.596 0.616 0.402 0.423 0.458 ...
#>  $ : num [1:1000, 1] 0.6 0.611 0.418 0.431 0.471 ...
#>  $ : num [1:1000, 1] 0.607 0.608 0.442 0.441 0.48 ...
#>  $ : num [1:1000, 1] 0.619 0.607 0.459 0.456 0.478 ...
#>  $ : num [1:1000, 1] 0.634 0.607 0.483 0.476 0.475 ...
#>  $ : num [1:1000, 1] 0.649 0.606 0.514 0.502 0.477 ...
#>  $ : num [1:1000, 1] 0.657 0.6 0.532 0.523 0.491 ...
#>  $ : num [1:1000, 1] 0.659 0.59 0.545 0.54 0.507 ...
#>  $ : num [1:1000, 1] 0.661 0.582 0.559 0.542 0.521 ...
#>  $ : num [1:1000, 1] 0.651 0.579 0.565 0.556 0.53 ...
#>  $ : num [1:1000, 1] 0.638 0.576 0.569 0.573 0.537 ...
#>  $ : num [1:1000, 1] 0.622 0.575 0.574 0.581 0.545 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.554 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.564 0.575 0.376 0.42 0.468 ...
#>  $ : num [1:1000, 1] 0.598 0.606 0.374 0.422 0.466 ...
#>  $ : num [1:1000, 1] 0.615 0.616 0.381 0.431 0.472 ...
#>  $ : num [1:1000, 1] 0.62 0.622 0.393 0.436 0.471 ...
#>  $ : num [1:1000, 1] 0.623 0.619 0.419 0.441 0.475 ...
#>  $ : num [1:1000, 1] 0.63 0.614 0.434 0.454 0.472 ...
#>  $ : num [1:1000, 1] 0.645 0.611 0.451 0.471 0.468 ...
#>  $ : num [1:1000, 1] 0.662 0.611 0.479 0.497 0.465 ...
#>  $ : num [1:1000, 1] 0.664 0.602 0.514 0.527 0.479 ...
#>  $ : num [1:1000, 1] 0.666 0.593 0.538 0.538 0.492 ...
#>  $ : num [1:1000, 1] 0.669 0.584 0.553 0.543 0.506 ...
#>  $ : num [1:1000, 1] 0.662 0.58 0.561 0.555 0.52 ...
#>  $ : num [1:1000, 1] 0.652 0.578 0.567 0.572 0.532 ...
#>  $ : num [1:1000, 1] 0.64 0.576 0.571 0.581 0.542 ...
#>  $ : num [1:1000, 1] 0.624 0.575 0.574 0.581 0.551 ...
#>  $ : num [1:1000, 1] 0.61 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.582 0.556 0.393 0.429 0.493 ...
#>  $ : num [1:1000, 1] 0.595 0.575 0.377 0.433 0.491 ...
#>  $ : num [1:1000, 1] 0.604 0.591 0.37 0.438 0.488 ...
#>  $ : num [1:1000, 1] 0.623 0.604 0.37 0.442 0.486 ...
#>  $ : num [1:1000, 1] 0.639 0.61 0.388 0.447 0.48 ...
#>  $ : num [1:1000, 1] 0.644 0.62 0.405 0.453 0.472 ...
#>  $ : num [1:1000, 1] 0.652 0.623 0.419 0.463 0.463 ...
#>  $ : num [1:1000, 1] 0.666 0.619 0.436 0.485 0.458 ...
#>  $ : num [1:1000, 1] 0.668 0.607 0.465 0.513 0.463 ...
#>  $ : num [1:1000, 1] 0.67 0.598 0.497 0.527 0.477 ...
#>  $ : num [1:1000, 1] 0.673 0.588 0.533 0.537 0.49 ...
#>  $ : num [1:1000, 1] 0.673 0.581 0.555 0.546 0.503 ...
#>  $ : num [1:1000, 1] 0.664 0.58 0.561 0.561 0.518 ...
#>  $ : num [1:1000, 1] 0.653 0.578 0.567 0.578 0.53 ...
#>  $ : num [1:1000, 1] 0.642 0.577 0.571 0.582 0.543 ...
#>  $ : num [1:1000, 1] 0.628 0.576 0.574 0.581 0.553 ...
#>  $ : num [1:1000, 1] 0.615 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>  $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#>   [list output truncated]
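A list like this can be scored candidate by candidate. The sketch below uses simulated predictions in the same shape, computes log loss on the labeled rows, and maps the best list position back to its ridge and lambda indices by inverting the index formula `(r - 1) * length(lambda_list) + i` from the SMMAL_ada_lasso source. The grid sizes `n_ridge` and `n_lambda` are hypothetical (the real ones come from param_fun()), and scoring on labeled rows in this way is an assumption about how the selection might be done, not a description of SMMAL's internals.

```r
set.seed(99)
n_all <- 50
labeled_indices <- sort(sample(n_all, 20))
y_labeled <- rbinom(20, 1, 0.5)

# Hypothetical grid sizes; the real ones come from param_fun().
n_ridge <- 2
n_lambda <- 3

# Fake candidate predictions: one n_all x 1 matrix per (ridge, lambda) pair.
fold_predictions <- lapply(seq_len(n_ridge * n_lambda), function(k) {
  matrix(runif(n_all, 0.05, 0.95), ncol = 1)
})

log_loss <- function(y_true, y_pred) {
  -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
}

# Score each candidate on the labeled rows only.
losses <- vapply(fold_predictions, function(m) {
  log_loss(y_labeled, m[labeled_indices, 1])
}, numeric(1))
best <- which.min(losses)

# Invert index = (r - 1) * n_lambda + i to recover the grid position.
r_best <- (best - 1) %/% n_lambda + 1
i_best <- (best - 1) %% n_lambda + 1
print(c(best = best, ridge_index = r_best, lambda_index = i_best))
```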