---
title: "Introduction to fairGATE"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to fairGATE}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include=FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 7,
  fig.height = 5,
  error = TRUE
)
```

# 1. Introduction

The **fairGATE** package provides a complete pipeline for training and evaluating a Gated Neural Network (GNN) designed to mitigate demographic bias in predictive modelling. The package implements a fairness-aware GNN that uses a custom loss function to enforce the *Equalized Odds* fairness criterion by minimising the variance in True Positive and False Positive Rates across subgroups.

This vignette demonstrates the full workflow using a small sample of the UCI Adult dataset to predict income, focusing on fairness across sex subgroups.

# 2. The fairGATE Workflow

The package is built around a logical sequence of functions:

- `prepare_data()`: Cleans and prepares the input data.
- `train_gnn()`: Trains the Gated Neural Network.
- `analyse_gnn_results()`: Conducts performance and gate analysis.
- `analyse_experts()`: Performs analysis of expert specialisation.
- `plot_sankey()`: Visualises the model's subject routing behaviour.

## 2.1. Loading Data

First, we load the necessary libraries and the dataset. We use a small in-package sample of the UCI Adult dataset. The outcome is **income** (1 = >50K, 0 = ≤50K) and the protected attribute is **sex**.

```{r}
library(fairGATE)
library(dplyr)
library(readxl)

# Load the UCI Adult dataset sample shipped with the package
data("adult_ready_small", package = "fairGATE")
adult_data <- adult_ready_small

# Trim whitespace in character columns and store the outcome as an integer
adult <- adult_data %>%
  mutate(
    across(where(is.character), ~ trimws(.x)),
    income = as.integer(income)
  )
```

## 2.2. Step 1: Prepare the Data

We use `prepare_data()` to process the raw data for the Male/Female analysis.

```{r}
# Drop unwanted columns (i.e. numeric cols and those with high multicollinearity)
cols_to_drop <- c("subjectid", "Row.names")

# Perform any other preprocessing steps (such as one-hot encoding) beforehand
# Fully prepared data goes here
prepared <- fairGATE::prepare_data(
  data = adult,
  outcome_var = "income",
  group_var = "sex",
  cols_to_remove = cols_to_drop
)
```

```{r, include = FALSE}
# --- Safety block: clean any non-finite or zero-variance columns before training ---
X <- prepared$X

# Replace non-finite entries in a column with the column median (or 0 if all are missing)
fix_col <- function(x) {
  x[!is.finite(x)] <- NA
  if (all(is.na(x))) return(rep(0, length(x)))
  x[is.na(x)] <- stats::median(x, na.rm = TRUE)
  x
}

# Replace any non-finite values
bad <- colSums(!is.finite(X)) > 0
if (any(bad)) X[, bad] <- apply(X[, bad, drop = FALSE], 2, fix_col)

# Drop zero-variance columns (these can cause NaNs on scaling)
zv <- apply(X, 2, function(v) sd(v, na.rm = TRUE) == 0)
if (any(zv)) X <- X[, !zv, drop = FALSE]

# Update prepared object
prepared$X <- X
prepared$feature_names <- colnames(X)

# Quick sanity check
stopifnot(sum(!is.finite(prepared$X)) == 0, ncol(prepared$X) > 0)
```

## 2.3. Step 2: Train a Small Demo Model

```{r train-demo, results='hide', message=FALSE, warning=FALSE}
# Train a small Gated Neural Network
trained_model <- fairGATE::train_gnn(
  prepared_data = prepared,
  run_tuning = FALSE,   # skip tuning for speed
  best_params = list(
    lr = 0.01,
    hidden_dim = 16,
    dropout_rate = 0.1,
    lambda = 0.0,
    temperature = 1.0
  ),
  num_repeats = 2,      # very short repeated split
  epochs = 20,          # fast CRAN-safe runtime
  verbose = FALSE
)
```

## 2.4. Step 3: Analyse Basic Performance and Gates

With the model trained, we run `analyse_gnn_results()` to generate all the standard performance plots and gate analyses.
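One of the gate analyses produced below is an entropy distribution. As a rough illustration of what such a quantity measures, the following sketch computes the Shannon entropy of some toy gate probabilities in base R. It assumes entropy is taken over each subject's gate probabilities; `toy_gates` and `row_entropy()` are made up for this example and are not part of fairGATE, whose internal computation may differ.

```{r}
# Toy gate probabilities (made-up values): one row per subject, one column per expert
toy_gates <- matrix(
  c(0.98, 0.02,
    0.50, 0.50,
    0.10, 0.90),
  ncol = 2, byrow = TRUE,
  dimnames = list(NULL, c("expert_1", "expert_2"))
)

# Shannon entropy of each row; entries with probability 0 contribute 0
row_entropy <- function(p) {
  terms <- ifelse(p > 0, -p * log(p), 0)
  rowSums(terms)
}

toy_entropy <- row_entropy(toy_gates)
round(toy_entropy, 3)
```

The evenly split row has the highest entropy (`log(2)` for two experts), while rows dominated by one expert have entropy close to zero.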
```{r}
# Run basic analysis
basic_analyses <- analyse_gnn_results(
  gnn_results = trained_model,
  prepared_data = prepared
)

# --- View all plots from the basic analysis ---
cat("## ROC Curve\n")
print(basic_analyses$roc_plot)

cat("\n## Calibration Plot\n")
print(basic_analyses$calibration_plot)

cat("\n## Gate Weight Distribution\n")
print(basic_analyses$gate_density_plot)

cat("\n## Gate Entropy Distribution\n")
print(basic_analyses$entropy_density_plot)
```

## 2.5. Step 4: Analyse Expert Specialisation

Now, we use `analyse_experts()` to investigate how the different expert networks have specialised their learning. The function summarises expert weights per subgroup, compares mean importance across groups, and produces difference or multi-group plots.

```{r}
exp_res <- analyse_experts(
  gnn_results = trained_model,   # from train_gnn()
  prepared_data = prepared,      # from prepare_data()
  top_n_features = 15,           # number of top features to visualise
  verbose = TRUE
)

# View the main objects returned
names(exp_res)
#> [1] "all_weights"          "means_by_group_wide"
#> [3] "pairwise_differences" "difference_plot"
#> [5] "multi_group_plot"     "top_features_multi"

# View first few feature importances
head(exp_res$means_by_group_wide)

# Example: view one pairwise difference table
names(exp_res$pairwise_differences)
#> [1] "Female_vs_Male"
head(exp_res$pairwise_differences[[1]])

# Visualise feature specialisation
if (!is.null(exp_res$difference_plot)) print(exp_res$difference_plot)
if (!is.null(exp_res$multi_group_plot)) print(exp_res$multi_group_plot)
```

## 2.6. Step 5: Visualise Subject Routing with a Sankey Plot

We can use `plot_sankey()` to create the key visualisation from the research paper: a Sankey diagram showing how subjects flow from their actual subgroup to their assigned expert. It auto-derives subgroup labels from `prepared` and subject IDs from `trained_model`.
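To make the routing idea concrete, here is a minimal base-R sketch on toy data. It assumes each subject is assigned to the expert with the largest gate probability, which may not match how fairGATE defines the assignment; `toy_group`, `toy_gate_probs`, and the other names are hypothetical and used only for illustration.

```{r}
# Toy data (made-up values): actual subgroup and gate probabilities per subject
toy_group <- c("Female", "Female", "Male", "Male", "Male")
toy_gate_probs <- matrix(
  c(0.9, 0.1,
    0.6, 0.4,
    0.2, 0.8,
    0.3, 0.7,
    0.7, 0.3),
  ncol = 2, byrow = TRUE,
  dimnames = list(NULL, c("expert_1", "expert_2"))
)

# Hard assignment: each subject goes to the expert with the largest gate probability
toy_assigned <- colnames(toy_gate_probs)[max.col(toy_gate_probs, ties.method = "first")]

# Counts per (subgroup, expert) pair; a Sankey diagram draws these as flow widths
toy_flows <- table(subgroup = toy_group, expert = toy_assigned)
toy_flows
```

Each cell of `toy_flows` corresponds to one ribbon in a Sankey diagram, from a subgroup on the left to an expert on the right.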
```{r}
# Generate and print the Sankey plot
p <- plot_sankey(
  prepared_data = prepared,       # from prepare_data()
  gnn_results = trained_model,    # from train_gnn()
  expert_results = exp_res,       # from analyse_experts()
  verbose = TRUE
)
print(p)
```

## 2.7. (Optional) Export Data for AI Fairness 360 or External Fairness Analysis

The `export_f360_csv()` function writes a clean, plug-and-play CSV for use in IBM AI Fairness 360, containing columns for subject IDs, true labels, predicted probabilities, and the sensitive attribute. It can also include gate probabilities if desired.

```{r f360_export, eval = FALSE, message = FALSE}
export_f360_csv(
  gnn_results = trained_model,            # from train_gnn()
  prepared_data = prepared,               # from prepare_data()
  path = "outputs/fairness360_input.csv",
  include_gate_cols = TRUE,               # include expert routing probabilities
  threshold = 0.5,                        # classification threshold for binary outcome
  verbose = TRUE
)
```
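
An external fairness analysis typically starts from group-wise error rates. As a minimal sketch of the Equalized Odds quantities described in the introduction, the chunk below computes the True Positive and False Positive Rate per group on toy data and summarises the gap by the variance across groups. All vectors and the `group_rates()` helper are made up for illustration; this is not the package's loss implementation, and the 0.5 cut-off simply mirrors the `threshold` argument used above.

```{r}
# Toy inputs (made-up values): true labels, predicted probabilities, sensitive attribute
toy_truth <- c(1, 1, 0, 0, 1, 1, 0, 0)
toy_prob  <- c(0.9, 0.4, 0.2, 0.7, 0.8, 0.6, 0.1, 0.3)
toy_sex   <- c("Female", "Female", "Female", "Female", "Male", "Male", "Male", "Male")

# Binary predictions at a 0.5 threshold
toy_pred <- as.integer(toy_prob >= 0.5)

# True and false positive rate within one group
group_rates <- function(truth, pred) {
  c(TPR = mean(pred[truth == 1] == 1),
    FPR = mean(pred[truth == 0] == 1))
}

# One row per group, with columns TPR and FPR
toy_rates <- t(sapply(split(seq_along(toy_truth), toy_sex), function(idx) {
  group_rates(toy_truth[idx], toy_pred[idx])
}))
toy_rates

# Equalized Odds holds when both columns are equal across groups;
# the variance of each column is one way to summarise the gap
toy_gap <- apply(toy_rates, 2, var)
toy_gap
```

A variance of zero in both columns would mean the toy predictions satisfy Equalized Odds exactly at this threshold.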