## ----setup, include = FALSE---------------------------------------------------
# Global knitr chunk options plus a fixed RNG seed for the document.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 7,
  fig.height = 5
)
set.seed(2025)

## ----install, eval = FALSE----------------------------------------------------
# # From the package source directory:
# devtools::install("swjm")
#
# # Or from a built tarball:
# install.packages("swjm_0.1.0.tar.gz", repos = NULL, type = "source")

## ----library------------------------------------------------------------------
library(swjm)

## ----gen-jfm------------------------------------------------------------------
# Simulate the "jfm" data set; the seed is reset so this chunk is reproducible
# on its own.
set.seed(123)
dat_jfm <- generate_data(n = 500, p = 10, scenario = 1, model = "jfm")
Data_jfm <- dat_jfm$data

# Preview: first rows, first eight columns
head(Data_jfm[, seq_len(8)])

## ----true-jfm-alpha-----------------------------------------------------------
round(dat_jfm$alpha_true, digits = 2)

## ----true-jfm-beta------------------------------------------------------------
round(dat_jfm$beta_true, digits = 2)

## ----gen-jscm-----------------------------------------------------------------
# Same simulation call with model = "jscm" and its own seed.
set.seed(456)
dat_jscm <- generate_data(n = 500, p = 10, scenario = 1, model = "jscm")
Data_jscm <- dat_jscm$data

## ----fit-jfm------------------------------------------------------------------
fit_jfm <- stagewise_fit(
  Data_jfm,
  model   = "jfm",
  penalty = "coop"  # cooperative lasso
)
fit_jfm

## ----path-explore-------------------------------------------------------------
p <- 10
# k is the index of the last column of the alpha matrix
# (presumably one column per path step -- confirm against stagewise_fit()).
k <- ncol(fit_jfm$alpha)
# Rows that are non-zero in either alpha or beta at column k.
active_final <- which(fit_jfm$alpha[, k] != 0 | fit_jfm$beta[, k] != 0)

## ----path-explore-alpha-------------------------------------------------------
round(fit_jfm$alpha[, k], digits = 4)

## ----path-summary-------------------------------------------------------------
summary(fit_jfm)

## ----plot-path, fig.height = 8------------------------------------------------
plot(fit_jfm)

## ----plot-path-re, fig.height = 5---------------------------------------------
plot(fit_jfm, which = "readmission")
## ----cv-jfm-prep--------------------------------------------------------------
# Build the lambda grid for cross-validation from the fitted path: keep only
# the positions returned by the (internal) extract_decreasing_indices() helper.
lambda_path <- fit_jfm$lambda
dec_idx <- swjm:::extract_decreasing_indices(lambda_path)
lambda_seq <- lambda_path[dec_idx]

## ----cv-jfm, cache = TRUE-----------------------------------------------------
set.seed(1)
cv_jfm <- cv_stagewise(
  Data_jfm,
  model      = "jfm",
  penalty    = "coop",
  lambda_seq = lambda_seq,
  K          = 3L
)
cv_jfm

## ----plot-cv------------------------------------------------------------------
plot(cv_jfm)

## ----coef-jfm-alpha-----------------------------------------------------------
# Non-zero alpha entries only.
round(cv_jfm$alpha[cv_jfm$alpha != 0], 4)

## ----coef-jfm-beta------------------------------------------------------------
# Non-zero beta entries only.
round(cv_jfm$beta[cv_jfm$beta != 0], 4)

## ----summary-jfm--------------------------------------------------------------
summary(cv_jfm)

## ----coef-vec-----------------------------------------------------------------
theta_best <- coef(cv_jfm)
length(theta_best)  # 2p

## ----baseline-----------------------------------------------------------------
bh <- baseline_hazard(cv_jfm, times = c(0.5, 1.0, 2.0, 4.0, 6.0))
print(bh)

## ----baseline-re--------------------------------------------------------------
bh_re <- baseline_hazard(cv_jfm, times = seq(0, 5, by = 0.5),
                         which = "readmission")
head(bh_re)

## ----predict-jfm, fig.height = 7----------------------------------------------
set.seed(7)
# 12 new subjects x 10 covariates. Draw one value per cell: supplying only 30
# draws for 120 cells is silently recycled by matrix() (120 is a multiple of
# 30), which would make the simulated subjects share the same 30 values.
newz <- matrix(rnorm(12 * 10), nrow = 12, ncol = 10)
rownames(newz) <- paste0("Patient_", seq_len(12))
colnames(newz) <- paste0("x", seq_len(10))

pred <- predict(cv_jfm, newdata = newz)
pred

## ----pred-survival------------------------------------------------------------
# Survival probabilities for all subjects at first few time points
round(pred$S_re[, 1:5], 3)

## ----plot-pred, fig.height = 8------------------------------------------------
plot(pred, which_subject = 7)

## ----plot-pred-re, fig.height = 5---------------------------------------------
plot(pred, which_subject = 2, which_process = "readmission")
## ----lasso, eval = FALSE------------------------------------------------------
# fit_lasso <- stagewise_fit(Data_jfm, model = "jfm", penalty = "lasso")
# set.seed(2)
# cv_lasso <- cv_stagewise(Data_jfm, model = "jfm", penalty = "lasso", K = 3L)
# summary(cv_lasso)

## ----group, eval = FALSE------------------------------------------------------
# fit_group <- stagewise_fit(Data_jfm, model = "jfm", penalty = "group")
# set.seed(3)
# cv_group <- cv_stagewise(Data_jfm, model = "jfm", penalty = "group", K = 3L)
# summary(cv_group)

## ----fit-jscm-----------------------------------------------------------------
# Regenerate the "jscm" data with the same seed and arguments as the earlier
# gen-jscm chunk, then fit the path.
set.seed(456)
dat_jscm <- generate_data(n = 500, p = 10, scenario = 1, model = "jscm")
Data_jscm <- dat_jscm$data

fit_jscm <- stagewise_fit(
  Data_jscm,
  model   = "jscm",
  penalty = "coop"
)
fit_jscm

## ----cv-jscm, cache = TRUE----------------------------------------------------
# Lambda grid: positions of the fitted path selected by the internal
# extract_decreasing_indices() helper.
lambda_path_jscm <- fit_jscm$lambda
dec_idx_jscm <- swjm:::extract_decreasing_indices(lambda_path_jscm)
lambda_seq_jscm <- lambda_path_jscm[dec_idx_jscm]

set.seed(10)
cv_jscm <- cv_stagewise(
  Data_jscm,
  model      = "jscm",
  penalty    = "coop",
  lambda_seq = lambda_seq_jscm,
  K          = 3L
)
cv_jscm

## ----plot-cv-jscm-------------------------------------------------------------
plot(cv_jscm)

## ----summary-jscm-------------------------------------------------------------
summary(cv_jscm)

## ----baseline-jscm------------------------------------------------------------
bh_jscm <- baseline_hazard(cv_jscm, times = c(0.5, 1.0, 2.0, 3.0, 4.0))
print(bh_jscm)

## ----predict-jscm-------------------------------------------------------------
set.seed(7)
# 3 new subjects x 10 covariates, uniform on [-1, 1] (30 draws for 30 cells).
newz_jscm <- matrix(runif(30, -1, 1), nrow = 3, ncol = 10)
rownames(newz_jscm) <- paste0("Patient_", seq_len(3))

pred_jscm <- predict(cv_jscm, newdata = newz_jscm)
pred_jscm

## ----predict-jscm-accel-------------------------------------------------------
round(pred_jscm$time_accel_re, digits = 3)

## ----plot-pred-jscm, fig.height = 8-------------------------------------------
plot(pred_jscm, which_subject = 1)

## ----interpret----------------------------------------------------------------
a <- cv_jfm$alpha
b <- cv_jfm$beta

# Positions of the non-zero entries in each coefficient vector, and the
# positions that are non-zero in both.
nz_a <- which(a != 0)
nz_b <- which(b != 0)
shared <- intersect(nz_a, nz_b)

# Split the shared positions by whether alpha and beta agree in sign.
# Subsetting an empty `shared` yields integer(0), so no explicit guard is
# needed for the empty case.
same_sign <- shared[sign(a[shared]) == sign(b[shared])]
opp_sign <- shared[sign(a[shared]) != sign(b[shared])]

## ----contrib-example----------------------------------------------------------
# First row of each contribution matrix returned by predict().
c1_re <- pred$contrib_re[1, ]
c1_de <- pred$contrib_de[1, ]

## ----contrib-re---------------------------------------------------------------
round(c1_re[c1_re != 0], digits = 4)

## ----contrib-de---------------------------------------------------------------
round(c1_de[c1_de != 0], digits = 4)

## ----coef-compare-------------------------------------------------------------
p <- 10

# Indices that are non-zero in the truth or in the estimate, for either
# coefficient vector.
show_jfm <- sort(which(
  dat_jfm$alpha_true != 0 | cv_jfm$alpha != 0 |
    dat_jfm$beta_true != 0 | cv_jfm$beta != 0
))

# True vs. estimated coefficients, rounded to 3 decimals.
coef_df <- data.frame(
  variable   = paste0("x", show_jfm),
  alpha_true = round(dat_jfm$alpha_true[show_jfm], 3),
  alpha_est  = round(cv_jfm$alpha[show_jfm], 3),
  beta_true  = round(dat_jfm$beta_true[show_jfm], 3),
  beta_est   = round(cv_jfm$beta[show_jfm], 3)
)
print(coef_df, row.names = FALSE)

## ----coef-compare-jscm--------------------------------------------------------
show_jscm <- sort(which(
  dat_jscm$alpha_true != 0 | cv_jscm$alpha != 0 |
    dat_jscm$beta_true != 0 | cv_jscm$beta != 0
))

coef_jscm <- data.frame(
  variable   = paste0("x", show_jscm),
  alpha_true = round(dat_jscm$alpha_true[show_jscm], 3),
  alpha_est  = round(cv_jscm$alpha[show_jscm], 3),
  beta_true  = round(dat_jscm$beta_true[show_jscm], 3),
  beta_est   = round(cv_jscm$beta[show_jscm], 3)
)
print(coef_jscm, row.names = FALSE)
## ----auc-prep-----------------------------------------------------------------
# Construct competing-risk dataset:
#   Keep first readmission (event==1 & t.start==0) + death/censor (event==0).
#   Status: 1 = first readmission, 2 = death, 0 = censored.
#
# Data: data frame with columns id, event, status, t.start and t.stop.
# Returns a list with `data` (the retained rows, ordered by id, t.start,
# t.stop) and `status` (integer code per retained row).
.cr_data <- function(Data) {
  keep <- Data$event == 0 | (Data$event == 1 & Data$t.start == 0)
  d3 <- Data[keep, ]
  d3 <- d3[order(d3$id, d3$t.start, d3$t.stop), ]
  status <- ifelse(
    d3$event == 1 & d3$status == 0, 1L,
    ifelse(d3$event == 0 & d3$status == 0, 0L, 2L)
  )
  list(data = d3, status = status)
}

cr_jfm <- .cr_data(Data_jfm)
cr_jscm <- .cr_data(Data_jscm)

# Baseline covariates (one row per subject: the first row seen for each id).
# The ids of those rows are kept so markers can be aligned by id, not position.
first_jfm <- !duplicated(Data_jfm$id)
first_jscm <- !duplicated(Data_jscm$id)
x_cols <- paste0("x", seq_len(p))
Z_jfm <- as.matrix(Data_jfm[first_jfm, x_cols])
Z_jscm <- as.matrix(Data_jscm[first_jscm, x_cols])
ids_jfm <- Data_jfm$id[first_jfm]
ids_jscm <- Data_jscm$id[first_jscm]

# Markers expanded to row level: alpha^T z for readmission, beta^T z for death.
# match() looks up each row's subject in the id vector that Z was built from;
# indexing the marker directly by id is only correct when ids are exactly
# 1..n in order of first appearance.
row_jfm <- match(cr_jfm$data$id, ids_jfm)
row_jscm <- match(cr_jscm$data$id, ids_jscm)
M_re_jfm <- drop(Z_jfm %*% cv_jfm$alpha)[row_jfm]
M_de_jfm <- drop(Z_jfm %*% cv_jfm$beta)[row_jfm]
M_re_jscm <- drop(Z_jscm %*% cv_jscm$alpha)[row_jscm]
M_de_jscm <- drop(Z_jscm %*% cv_jscm$beta)[row_jscm]

## ----auc, cache = TRUE--------------------------------------------------------
# NOTE(review): installing a package as a side effect of running this script
# may be undesirable (e.g. on build servers); behaviour kept as in the original.
if (!requireNamespace("timeROC", quietly = TRUE)) install.packages("timeROC")
library(survival)
library(timeROC)

# Evaluation grid: n (default 20) equally spaced points spanning the
# 10th-85th percentile of the times with a non-zero status.
.tgrid <- function(t_vec, status, n = 20) {
  t_ev <- t_vec[status > 0]
  seq(
    quantile(t_ev, 0.10, names = FALSE),
    quantile(t_ev, 0.85, names = FALSE),
    length.out = n
  )
}
t_jfm <- .tgrid(cr_jfm$data$t.stop, cr_jfm$status)
t_jscm <- .tgrid(cr_jscm$data$t.stop, cr_jscm$status)

# Readmission AUC: alpha^T z marker, cause = 1
roc_re_jfm <- timeROC(
  T = cr_jfm$data$t.stop, delta = cr_jfm$status,
  marker = M_re_jfm, cause = 1,
  weighting = "marginal", times = t_jfm,
  ROC = FALSE, iid = FALSE
)
roc_re_jscm <- timeROC(
  T = cr_jscm$data$t.stop, delta = cr_jscm$status,
  marker = M_re_jscm, cause = 1,
  weighting = "marginal", times = t_jscm,
  ROC = FALSE, iid = FALSE
)
# Death AUC: beta^T z marker, cause = 2
roc_de_jfm <- timeROC(
  T = cr_jfm$data$t.stop, delta = cr_jfm$status,
  marker = M_de_jfm, cause = 2,
  weighting = "marginal", times = t_jfm,
  ROC = FALSE, iid = FALSE
)
roc_de_jscm <- timeROC(
  T = cr_jscm$data$t.stop, delta = cr_jscm$status,
  marker = M_de_jscm, cause = 2,
  weighting = "marginal", times = t_jscm,
  ROC = FALSE, iid = FALSE
)

## ----auc-plot, fig.height = 5, fig.width = 8----------------------------------
# Extract the AUC(t) vector from a timeROC result.
#
# roc          Object returned by timeROC().
# cause        Accepted for backward compatibility with existing calls; it is
#              NOT used to pick the element. In a competing-risks timeROC
#              result, `AUC_1` and `AUC_2` are the AUCs under control
#              definitions (i) and (ii) for the cause the object was fitted
#              with -- they are not "cause 1" and "cause 2". Selecting
#              `AUC_<cause>` would compare the two processes under different
#              control definitions.
# control_def  Which control definition to read (1 or 2); default 1, so every
#              curve uses the same definition.
#
# Returns a plain numeric vector; all NA (one per evaluation time) when no
# numeric AUC element is found.
.get_auc <- function(roc, cause, control_def = 1L) {
  auc <- roc[[paste0("AUC_", control_def)]]
  # Fall back to the single `AUC` element; `[[` avoids `$` partial matching.
  if (is.null(auc)) auc <- roc[["AUC"]]
  if (is.null(auc) || !is.numeric(auc)) {
    return(rep(NA_real_, length(roc$times)))
  }
  # Drop a leading extra entry if the AUC vector is one longer than `times`.
  if (length(auc) == length(roc$times) + 1) auc <- auc[-1]
  as.numeric(auc)
}

# Draw one panel: readmission AUC(t) as a solid line, death AUC(t) as a dashed
# line, with a dotted reference line at 0.5.
.plot_auc_panel <- function(times, roc_re, roc_de, title) {
  plot(times, .get_auc(roc_re, 1), type = "l", lwd = 2, col = "steelblue",
       xlab = "Time", ylab = "AUC(t)", main = title, ylim = c(0.4, 1))
  lines(times, .get_auc(roc_de, 2), lwd = 2, col = "tomato", lty = 2)
  abline(h = 0.5, lty = 3, col = "grey60")
  legend("bottomleft", c("Readmission", "Death"),
         col = c("steelblue", "tomato"), lwd = 2, lty = c(1, 2),
         bty = "n", cex = 0.85)
}

# Two panels side by side; the previous graphics settings are restored after.
old_par <- par(mfrow = c(1, 2), mar = c(4.5, 4, 3, 1))
.plot_auc_panel(t_jfm, roc_re_jfm, roc_de_jfm, "JFM")
.plot_auc_panel(t_jscm, roc_re_jscm, roc_de_jscm, "JSCM")
par(old_par)