This vignette demonstrates how to use the SMMAL package to estimate the Average Treatment Effect (ATE) using semi-supervised machine learning. We provide an example dataset and walk through the required input format and function usage.
The sample data contain 1000 observations, with 60% of Y and A missing at random. Y is the outcome, A is the treatment indicator, X holds the covariates, and S holds the surrogates.
For the sample data, missingness occurs at random and is encoded as NA. This package can handle datasets with a high proportion of missing values, but it requires a sufficiently large sample size to ensure that each fold in cross-validation contains at least 20 labeled observations.
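The labeled-observation requirement can be checked before fitting. The sketch below counts non-missing outcomes per fold in base R. The helper name labeled_per_fold and the round-robin fold assignment are illustrative assumptions; SMMAL's internal fold assignment may differ.

```r
# Hypothetical helper (not part of SMMAL): count labeled (non-NA) outcomes per fold.
labeled_per_fold <- function(Y, foldid, nfold) {
  tabulate(foldid[!is.na(Y)], nbins = nfold)
}

n <- 1000
nfold <- 5
# Toy outcome with 60% missing, mirroring the proportion in the sample data.
Y <- c(rep(NA, 600), rep(c(0, 1), 200))
# Illustrative round-robin fold assignment (an assumption, not SMMAL's own).
foldid <- rep(seq_len(nfold), length.out = n)

counts <- labeled_per_fold(Y, foldid, nfold)
counts
#> [1] 80 80 80 80 80
all(counts >= 20)
#> [1] TRUE
```

If any count falls below 20, reducing nfold or collecting more labeled data are the obvious options.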
The inputs S and X must be data frames, even if each consists of a single vector.
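As a minimal sketch of this requirement, a single surrogate vector or a covariate matrix can be wrapped in a data frame with base R. The object names s_vec and x_mat and their values are hypothetical.

```r
# Hypothetical inputs: one surrogate stored as a vector, covariates as a matrix.
s_vec <- c(0.2, 1.5, -0.3, 0.8)
x_mat <- matrix(1:8, ncol = 2)

# Wrap them so that S and X are data frames.
S <- data.frame(S1 = s_vec)
X <- as.data.frame(x_mat)

is.data.frame(S) && is.data.frame(X)
#> [1] TRUE
```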
Users can choose which model to use for the nuisance functions by setting the cf_model parameter. If cf_model is not specified, it defaults to “bspline”.
After cross-validation and prediction, the best-performing model is selected based on the lowest cross-entropy (log loss). Users can control the number of cross-validation folds by setting the nfold parameter. If nfold is not specified, it defaults to 5.
SMMAL_output1 <- SMMAL(Y = Y, A = A, S = S, X = X)
print(SMMAL_output1)
#> $est
#> [1] 0.1048938
#>
#> $se
#> [1] 0.04907399
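The selection criterion described above, lowest cross-entropy, can be illustrated in a few lines of base R. This is a sketch under stated assumptions: log_loss_sketch is a hypothetical stand-in (the package's own log_loss may clip or weight differently), and the two candidate prediction vectors are made up for illustration.

```r
# Hypothetical stand-in for a log-loss function; SMMAL's own version may differ.
log_loss_sketch <- function(y_true, y_pred, eps = 1e-15) {
  # Clip probabilities away from 0 and 1 so that log() stays finite.
  p <- pmin(pmax(y_pred, eps), 1 - eps)
  -mean(y_true * log(p) + (1 - y_true) * log(1 - p))
}

# Toy labels and predicted probabilities from two made-up candidate models.
y <- c(1, 0, 1, 1, 0)
candidates <- list(
  model_a = c(0.9, 0.2, 0.8, 0.7, 0.1),
  model_b = c(0.6, 0.5, 0.5, 0.6, 0.4)
)

# Evaluate each candidate and keep the one with the smallest loss.
losses <- vapply(candidates, log_loss_sketch, numeric(1), y_true = y)
best <- names(which.min(losses))
best
#> [1] "model_a"
```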
Other options for cf_model are “xgboost”:
SMMAL_output2 <- SMMAL(Y = Y, A = A, S = S, X = X, cf_model = "xgboost")
print(SMMAL_output2)
#> $est
#> [1] 0.09174881
#>
#> $se
#> [1] 0.05081966
or “randomforest”:
SMMAL_output3 <- SMMAL(Y = Y, A = A, S = S, X = X, cf_model = "randomforest")
print(SMMAL_output3)
#> $est
#> [1] 0.09918991
#>
#> $se
#> [1] 0.05248152
or “glm”.
Users may customize the feature-selection or penalization strategy by supplying their own function through the custom_model_fun argument. To do so, pass a function that accepts the following arguments:
(X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, and nfold are supplied internally by SMMAL to partition and fit the data.)
(log_loss is a function for computing cross-entropy (log loss). Your function can call log_loss(true_labels, predicted_probs) to evaluate each tuning parameter.)
Below is an example showing how to plug in the packaged SMMAL_ada_lasso() as custom_model_fun. In practice, you could substitute any function with the same signature and return type:
SMMAL_output5 <- SMMAL(Y = Y, A = A, S = S, X = X, custom_model_fun = SMMAL_ada_lasso)
print(SMMAL_output5)
#> $est
#> [1] 0.1280551
#>
#> $se
#> [1] 0.05729187
SMMAL_ada_lasso
#> function (X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices,
#>     nfold, log_loss)
#> {
#>     fold_predictions <- NULL
#>     param_grid <- param_fun()
#>     ridge_list <- param_grid$ridge
#>     lambda_list <- param_grid$lambda
#>     fold_predictions <- vector("list", length(ridge_list) * length(lambda_list))
#>     for (r in seq_along(ridge_list)) {
#>         ridge_val <- ridge_list[[r]]
#>         for (i in seq_along(lambda_list)) {
#>             lambda <- lambda_list[[i]]
#>             ridge_fit_all <- glmnet::glmnet(X, Y, lambda = ridge_val,
#>                 alpha = 0, family = "binomial")
#>             ridge_coef <- as.numeric(coef(ridge_fit_all))[-1]
#>             penalty_factors <- 1/(abs(ridge_coef) + 1e-04)
#>             all_preds_matrix <- matrix(NA, nrow = length(foldid),
#>                 ncol = 1)
#>             for (ifold in seq_len(nfold)) {
#>                 trainpos <- which((foldid_labelled != ifold) &
#>                   sub_set[labeled_indices])
#>                 testpos <- which(foldid == ifold)
#>                 X_train <- as.matrix(X[trainpos, , drop = FALSE])
#>                 Y_train <- as.numeric(Y[trainpos])
#>                 X_test <- as.matrix(X_full[testpos, , drop = FALSE])
#>                 valid_idx <- which(!is.na(Y_train))
#>                 X_train <- X_train[valid_idx, , drop = FALSE]
#>                 Y_train <- Y_train[valid_idx]
#>                 fit <- glmnet::glmnet(X_train, Y_train, lambda = lambda,
#>                   alpha = 1, family = "binomial", penalty.factor = penalty_factors,
#>                   maxit = 1e+06)
#>                 preds <- predict(fit, newx = X_test, type = "response")
#>                 all_preds_matrix[testpos, ] <- preds
#>             }
#>             fold_predictions[[(r - 1) * length(lambda_list) +
#>                 i]] <- all_preds_matrix
#>         }
#>     }
#>     return(fold_predictions)
#> }
#> <bytecode: 0x00000253d775deb8>
#> <environment: namespace:SMMAL>
str(data_loaded)
#> List of 9
#> $ X : num [1:400, 1:50] 0.672 0.056 0.108 0.142 0.371 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : chr [1:400] "1" "2" "3" "4" ...
#> .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#> $ Y : int [1:400] 0 0 0 0 1 0 0 1 1 0 ...
#> $ X_full : num [1:1000, 1:50] 0.672 0.056 0.108 0.142 0.732 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : NULL
#> .. ..$ : chr [1:50] "X" "V2" "V3" "V4" ...
#> $ foldid : num [1:1000] 1 4 5 3 2 3 1 4 4 5 ...
#> $ foldid_labelled: num [1:400] 1 4 5 3 1 4 5 2 2 3 ...
#> $ sub_set : logi [1:400] TRUE TRUE TRUE TRUE TRUE TRUE ...
#> $ labeled_indices: int [1:400] 1 2 3 4 7 9 12 15 17 18 ...
#> $ nfold : num 5
#> $ log_loss :function (y_true, y_pred)
#> ..- attr(*, "srcref")= 'srcref' int [1:8] 1 15 5 3 15 3 1 5
#> .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x00000253d568a9d0>
Input: X, Y, X_full, foldid, foldid_labelled, sub_set, labeled_indices, nfold, log_loss
X: Matrix of predictors for the labeled observations.
Y: Binary outcome vector aligned with the rows of X; any NA entries are dropped before fitting.
X_full: Matrix of predictors for all observations, labeled and unlabeled.
foldid: A vector assigning each observation (labeled or unlabeled) to a fold.
foldid_labelled: Integer vector assigning labeled rows to CV folds (1 to nfold); NA for unlabeled.
sub_set: Logical or integer vector indicating rows included in supervised CV.
labeled_indices: Indices of labeled observations (where Y is not missing).
nfold: Number of cross-validation folds (e.g., 5 or 10).
log_loss: Function that computes log-loss: log_loss(true_labels, pred_probs) returns a single numeric.
Output: fold_predictions
When you use SMMAL_ada_lasso() as a custom_model_fun, it returns a list with one element per (ridge, lambda) combination. Each element is a one-column numeric matrix with one row per observation (labeled and unlabeled), containing the cross-validated predicted probabilities for that combination.
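To make the signature and return type concrete, here is a minimal sketch of a custom function with the same arguments. It swaps in plain logistic regression via stats::glm instead of the adaptive lasso, has no tuning grid, and therefore returns a list with a single one-column matrix. The name my_glm_model and the toy data are hypothetical. For simplicity the sketch accepts sub_set, labeled_indices, and log_loss without using them, so whether SMMAL accepts it unchanged is an assumption that has not been verified here.

```r
# Hypothetical custom model: unpenalized logistic regression, no tuning grid.
my_glm_model <- function(X, Y, X_full, foldid, foldid_labelled, sub_set,
                         labeled_indices, nfold, log_loss) {
  # One column of out-of-fold predictions, one row per observation in X_full.
  preds <- matrix(NA_real_, nrow = length(foldid), ncol = 1)
  for (ifold in seq_len(nfold)) {
    # Train on labeled rows outside the current fold, skipping missing outcomes.
    trainpos <- which(foldid_labelled != ifold & !is.na(Y))
    # Predict for every observation assigned to the current fold.
    testpos <- which(foldid == ifold)
    train_df <- data.frame(y = Y[trainpos], X[trainpos, , drop = FALSE])
    fit <- stats::glm(y ~ ., data = train_df, family = stats::binomial())
    test_df <- data.frame(X_full[testpos, , drop = FALSE])
    preds[testpos, ] <- stats::predict(fit, newdata = test_df, type = "response")
  }
  # A list, so the return type matches that of SMMAL_ada_lasso().
  list(preds)
}

# Toy data (made up): 200 observations, every second one labeled.
set.seed(1)
n <- 200
X_full <- matrix(rnorm(n * 2), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
Y_full <- rbinom(n, 1, plogis(X_full[, 1]))
labeled_indices <- seq(1, n, by = 2)
foldid <- rep(1:5, length.out = n)

out <- my_glm_model(
  X = X_full[labeled_indices, , drop = FALSE],
  Y = Y_full[labeled_indices],
  X_full = X_full,
  foldid = foldid,
  foldid_labelled = foldid[labeled_indices],
  sub_set = rep(TRUE, length(labeled_indices)),
  labeled_indices = labeled_indices,
  nfold = 5,
  log_loss = NULL
)
dim(out[[1]])
#> [1] 200   1
```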
Below is a sample run of SMMAL_ada_lasso() and its output:
SMMAL_fold_predictions <- SMMAL_ada_lasso(
  X = data_loaded$X,
  Y = data_loaded$Y,
  X_full = data_loaded$X_full,
  foldid = data_loaded$foldid,
  foldid_labelled = data_loaded$foldid_labelled,
  sub_set = data_loaded$sub_set,
  labeled_indices = data_loaded$labeled_indices,
  nfold = data_loaded$nfold,
  log_loss = data_loaded$log_loss
)
str(SMMAL_fold_predictions)
#> List of 100
#> $ : num [1:1000, 1] 0.485 0.486 0.439 0.359 0.442 ...
#> $ : num [1:1000, 1] 0.482 0.53 0.421 0.369 0.426 ...
#> $ : num [1:1000, 1] 0.485 0.548 0.41 0.38 0.425 ...
#> $ : num [1:1000, 1] 0.496 0.566 0.408 0.393 0.434 ...
#> $ : num [1:1000, 1] 0.51 0.583 0.4 0.402 0.439 ...
#> $ : num [1:1000, 1] 0.525 0.602 0.395 0.409 0.441 ...
#> $ : num [1:1000, 1] 0.54 0.61 0.399 0.415 0.447 ...
#> $ : num [1:1000, 1] 0.559 0.611 0.409 0.423 0.455 ...
#> $ : num [1:1000, 1] 0.58 0.607 0.42 0.431 0.463 ...
#> $ : num [1:1000, 1] 0.588 0.603 0.438 0.434 0.474 ...
#> $ : num [1:1000, 1] 0.598 0.602 0.46 0.443 0.484 ...
#> $ : num [1:1000, 1] 0.61 0.603 0.485 0.461 0.482 ...
#> $ : num [1:1000, 1] 0.623 0.603 0.514 0.48 0.483 ...
#> $ : num [1:1000, 1] 0.637 0.601 0.525 0.498 0.494 ...
#> $ : num [1:1000, 1] 0.647 0.594 0.538 0.517 0.508 ...
#> $ : num [1:1000, 1] 0.648 0.584 0.552 0.534 0.519 ...
#> $ : num [1:1000, 1] 0.642 0.578 0.563 0.543 0.526 ...
#> $ : num [1:1000, 1] 0.628 0.577 0.567 0.562 0.533 ...
#> $ : num [1:1000, 1] 0.621 0.575 0.572 0.581 0.54 ...
#> $ : num [1:1000, 1] 0.613 0.575 0.574 0.581 0.549 ...
#> $ : num [1:1000, 1] 0.475 0.404 0.469 0.353 0.455 ...
#> $ : num [1:1000, 1] 0.485 0.443 0.448 0.359 0.451 ...
#> $ : num [1:1000, 1] 0.49 0.484 0.429 0.369 0.45 ...
#> $ : num [1:1000, 1] 0.494 0.526 0.412 0.381 0.436 ...
#> $ : num [1:1000, 1] 0.499 0.557 0.406 0.394 0.431 ...
#> $ : num [1:1000, 1] 0.514 0.576 0.4 0.404 0.437 ...
#> $ : num [1:1000, 1] 0.529 0.595 0.391 0.41 0.441 ...
#> $ : num [1:1000, 1] 0.545 0.612 0.395 0.412 0.447 ...
#> $ : num [1:1000, 1] 0.563 0.614 0.404 0.42 0.454 ...
#> $ : num [1:1000, 1] 0.581 0.609 0.415 0.428 0.462 ...
#> $ : num [1:1000, 1] 0.591 0.605 0.433 0.433 0.474 ...
#> $ : num [1:1000, 1] 0.601 0.603 0.455 0.442 0.484 ...
#> $ : num [1:1000, 1] 0.613 0.604 0.476 0.46 0.482 ...
#> $ : num [1:1000, 1] 0.627 0.604 0.509 0.482 0.48 ...
#> $ : num [1:1000, 1] 0.642 0.603 0.523 0.5 0.489 ...
#> $ : num [1:1000, 1] 0.651 0.596 0.536 0.519 0.504 ...
#> $ : num [1:1000, 1] 0.652 0.586 0.549 0.537 0.518 ...
#> $ : num [1:1000, 1] 0.65 0.579 0.561 0.541 0.525 ...
#> $ : num [1:1000, 1] 0.637 0.577 0.566 0.559 0.532 ...
#> $ : num [1:1000, 1] 0.622 0.575 0.571 0.579 0.539 ...
#> $ : num [1:1000, 1] 0.48 0.455 0.436 0.384 0.468 ...
#> $ : num [1:1000, 1] 0.502 0.493 0.409 0.397 0.459 ...
#> $ : num [1:1000, 1] 0.525 0.535 0.396 0.407 0.452 ...
#> $ : num [1:1000, 1] 0.539 0.578 0.386 0.413 0.447 ...
#> $ : num [1:1000, 1] 0.557 0.614 0.385 0.416 0.445 ...
#> $ : num [1:1000, 1] 0.575 0.62 0.391 0.419 0.45 ...
#> $ : num [1:1000, 1] 0.596 0.616 0.402 0.423 0.458 ...
#> $ : num [1:1000, 1] 0.6 0.611 0.418 0.431 0.471 ...
#> $ : num [1:1000, 1] 0.607 0.608 0.442 0.441 0.48 ...
#> $ : num [1:1000, 1] 0.619 0.607 0.459 0.456 0.478 ...
#> $ : num [1:1000, 1] 0.634 0.607 0.483 0.476 0.475 ...
#> $ : num [1:1000, 1] 0.649 0.606 0.514 0.502 0.477 ...
#> $ : num [1:1000, 1] 0.657 0.6 0.532 0.523 0.491 ...
#> $ : num [1:1000, 1] 0.659 0.59 0.545 0.54 0.507 ...
#> $ : num [1:1000, 1] 0.661 0.582 0.559 0.542 0.521 ...
#> $ : num [1:1000, 1] 0.651 0.579 0.565 0.556 0.53 ...
#> $ : num [1:1000, 1] 0.638 0.576 0.569 0.573 0.537 ...
#> $ : num [1:1000, 1] 0.622 0.575 0.574 0.581 0.545 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.554 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.564 0.575 0.376 0.42 0.468 ...
#> $ : num [1:1000, 1] 0.598 0.606 0.374 0.422 0.466 ...
#> $ : num [1:1000, 1] 0.615 0.616 0.381 0.431 0.472 ...
#> $ : num [1:1000, 1] 0.62 0.622 0.393 0.436 0.471 ...
#> $ : num [1:1000, 1] 0.623 0.619 0.419 0.441 0.475 ...
#> $ : num [1:1000, 1] 0.63 0.614 0.434 0.454 0.472 ...
#> $ : num [1:1000, 1] 0.645 0.611 0.451 0.471 0.468 ...
#> $ : num [1:1000, 1] 0.662 0.611 0.479 0.497 0.465 ...
#> $ : num [1:1000, 1] 0.664 0.602 0.514 0.527 0.479 ...
#> $ : num [1:1000, 1] 0.666 0.593 0.538 0.538 0.492 ...
#> $ : num [1:1000, 1] 0.669 0.584 0.553 0.543 0.506 ...
#> $ : num [1:1000, 1] 0.662 0.58 0.561 0.555 0.52 ...
#> $ : num [1:1000, 1] 0.652 0.578 0.567 0.572 0.532 ...
#> $ : num [1:1000, 1] 0.64 0.576 0.571 0.581 0.542 ...
#> $ : num [1:1000, 1] 0.624 0.575 0.574 0.581 0.551 ...
#> $ : num [1:1000, 1] 0.61 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.582 0.556 0.393 0.429 0.493 ...
#> $ : num [1:1000, 1] 0.595 0.575 0.377 0.433 0.491 ...
#> $ : num [1:1000, 1] 0.604 0.591 0.37 0.438 0.488 ...
#> $ : num [1:1000, 1] 0.623 0.604 0.37 0.442 0.486 ...
#> $ : num [1:1000, 1] 0.639 0.61 0.388 0.447 0.48 ...
#> $ : num [1:1000, 1] 0.644 0.62 0.405 0.453 0.472 ...
#> $ : num [1:1000, 1] 0.652 0.623 0.419 0.463 0.463 ...
#> $ : num [1:1000, 1] 0.666 0.619 0.436 0.485 0.458 ...
#> $ : num [1:1000, 1] 0.668 0.607 0.465 0.513 0.463 ...
#> $ : num [1:1000, 1] 0.67 0.598 0.497 0.527 0.477 ...
#> $ : num [1:1000, 1] 0.673 0.588 0.533 0.537 0.49 ...
#> $ : num [1:1000, 1] 0.673 0.581 0.555 0.546 0.503 ...
#> $ : num [1:1000, 1] 0.664 0.58 0.561 0.561 0.518 ...
#> $ : num [1:1000, 1] 0.653 0.578 0.567 0.578 0.53 ...
#> $ : num [1:1000, 1] 0.642 0.577 0.571 0.582 0.543 ...
#> $ : num [1:1000, 1] 0.628 0.576 0.574 0.581 0.553 ...
#> $ : num [1:1000, 1] 0.615 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> $ : num [1:1000, 1] 0.609 0.575 0.574 0.581 0.559 ...
#> [list output truncated]
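The position of each list element follows from the indexing in SMMAL_ada_lasso(), where element (r - 1) * length(lambda_list) + i holds the predictions for the r-th ridge value and the i-th lambda. The sketch below inverts that formula. The helper name index_to_grid is hypothetical, and the grid size of 20 lambda values is an illustrative assumption; the actual grid comes from param_fun().

```r
# Hypothetical helper: map a list position k back to its (ridge, lambda) indices,
# inverting k = (r - 1) * n_lambda + i from the function body.
index_to_grid <- function(k, n_lambda) {
  c(ridge = (k - 1) %/% n_lambda + 1, lambda = (k - 1) %% n_lambda + 1)
}

n_lambda <- 20  # assumed grid size, for illustration only
index_to_grid(1, n_lambda)
#>  ridge lambda 
#>      1      1 
index_to_grid(21, n_lambda)
#>  ridge lambda 
#>      2      1 
index_to_grid(100, n_lambda)
#>  ridge lambda 
#>      5     20 
```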