Stratified k-fold Cross-Validation in R (Example)
In this R tutorial, you’ll learn how to draw the folds for cross-validation stratified by class. Stratified folds are especially useful when your variable of interest is imbalanced, that is, when the class frequencies differ greatly. Stratification ensures that the class frequencies in all folds are similar to those in the original data.
Let’s get started.
Exemplifying Data & Add-On Packages
We use the data.table and rpart packages, plus the caret package for computing confusion matrices. Note that we have an extra overview blog post about the data.table package here.
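```r
install.packages("data.table")  # Install data.table package
library("data.table")           # Load data.table package
install.packages("rpart")       # Install rpart package
library("rpart")                # Load rpart package
install.packages("caret")       # Install caret package (used below via caret::confusionMatrix())
```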
We create some example data according to a multinomial logistic regression model. First, we create an auxiliary data table X with four categorical variables and convert it into a model matrix. From this model matrix, we compute class probabilities via a generalized linear model with a multinomial logit link and draw a class variable y with three classes.
For the formulas of the generalized linear model with multinomial link, we refer you to Agresti (2012) Categorical Data Analysis, especially Chapter 8.
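As a compact reference for the code that follows, the class probabilities it computes have the baseline-category logit form, with class 1 as the reference category and $x_i$ denoting row $i$ of the model matrix X_mod_mat. In the code, eta_1 and eta_2 hold $\eta_{i2}$ and $\eta_{i3}$ (built from beta_vec1 and beta_vec2), and the 0 column in cbind(0, eta_1, eta_2) supplies the $e^0 = 1$ term for the reference class.

```latex
P(y_i = 1 \mid x_i) = \frac{1}{1 + e^{\eta_{i2}} + e^{\eta_{i3}}}, \qquad
P(y_i = j \mid x_i) = \frac{e^{\eta_{ij}}}{1 + e^{\eta_{i2}} + e^{\eta_{i3}}}, \quad j = 2, 3,
\qquad \text{with } \eta_{ij} = x_i^\top \beta_j .
```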
```r
N <- 500                                     # Number of observations
set.seed(345)                                # Seed for reproducible results
X <- data.table(factor(sample(letters[1:20], N, replace = TRUE)),  # Create auxiliary information:
                factor(sample(letters[1:20], N, replace = TRUE)),  # four factors with 20 levels each
                factor(sample(letters[1:20], N, replace = TRUE)),
                factor(sample(letters[1:20], N, replace = TRUE)))
X_mod_mat <- model.matrix(~., X)             # Model matrix: intercept + 4 * 19 dummies = 77 columns
beta_vec1 <- c(0, rep(.2, 76))               # Coefficients for class 2 (class 1 is the reference)
beta_vec2 <- c(0, rep(.3, 76))               # Coefficients for class 3
eta_1 <- X_mod_mat %*% beta_vec1             # Linear predictor for class 2
eta_2 <- X_mod_mat %*% beta_vec2             # Linear predictor for class 3
probs <- exp(cbind(0, eta_1, eta_2))         # Exponentiated linear predictors (reference class: 0)
probs <- probs / rowSums(probs)              # Normalize each row to class probabilities
cum_probs <- t(apply(probs, 1, cumsum))      # Cumulative probabilities per observation
```
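From the cumulative probabilities, we can draw the actual classes: for each observation, we draw a uniform random number and take the first class whose cumulative probability exceeds it.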
```r
y <- factor(apply(cum_probs, 1, function(x) {  # Draw concrete classes:
  min(which(x > runif(1)))                      # first class whose cumulative probability exceeds a uniform draw
}))
prop.table(table(y)) * 100                      # Relative frequencies of generated classes (in %)
# y
#    1    2    3
# 14.2 35.0 50.8
```
As the previous output of the RStudio console shows, the variable y consists of three classes with clearly unequal frequencies of about 14%, 35%, and 51%.
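The row-wise apply() call above draws each class by inverse transform sampling. A vectorized sketch of the same idea is shown below; draw_classes() is a hypothetical helper, not part of the original code. It compares one uniform draw per row with all cumulative probabilities at once: the drawn class is one plus the number of cumulative probabilities that do not exceed the uniform draw.

```r
# Hypothetical helper: vectorized inverse transform sampling of classes.
# probs is an n x K matrix whose rows are class probabilities summing to 1.
draw_classes <- function(probs) {
  cum_probs <- t(apply(probs, 1, cumsum))   # row-wise cumulative probabilities
  u <- runif(nrow(probs))                    # one uniform draw per observation
  # u is recycled down the columns, so cum_probs[i, j] is compared with u[i]
  k <- 1 + rowSums(cum_probs <= u)
  k <- pmin(k, ncol(probs))                  # guard against rounding in the last cumulative sum
  factor(k, levels = seq_len(ncol(probs)))
}

# Usage: simulated frequencies should be close to 0.2, 0.3 and 0.5
set.seed(1)
p <- matrix(c(0.2, 0.3, 0.5), nrow = 10000, ncol = 3, byrow = TRUE)
prop.table(table(draw_classes(p)))
```

Counting with rowSums() avoids calling a function per row, which can matter for large N; the statistical result is the same as with the min(which(...)) approach.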
Example: Stratified k-Fold Cross-Validation for Classification Tree
In this section, I’ll demonstrate how to create stratified folds for cross-validation. For a general motivation of cross-validation, take a look at this post first. We also have a blog post about k-fold cross-validation here.
In the following small example, we want to tune one of the parameters of the function rpart() from the rpart package. One of its options, which you can look up via ?rpart.control, is minsplit: the minimum number of observations that must exist in a node in order for a split to be attempted. Its default value is 20. We use cross-validation to see how rpart performs under different values of minsplit and consider the following integers:
```r
choices_minsplit <- c(5, 10, 20, 40, 60)  # Candidate values for minsplit
```
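To get a feel for what minsplit controls, the following sketch fits one tree per candidate value on R's built-in iris data (an assumption for illustration, not the data of this tutorial) and counts the leaves of each tree. Note that rpart.control() derives minbucket from minsplit by default (round(minsplit / 3)), so changing minsplit also changes the minimum leaf size.

```r
library(rpart)

# Fit one classification tree per candidate minsplit value on the built-in
# iris data and count the terminal nodes (leaves) of each tree
minsplit_values <- c(5, 10, 20, 40, 60)
n_leaves <- sapply(minsplit_values, function(ms) {
  fit <- rpart(Species ~ ., data = iris, method = "class",
               control = rpart.control(minsplit = ms))
  sum(fit$frame$var == "<leaf>")   # leaves are marked "<leaf>" in fit$frame
})
names(n_leaves) <- paste0("minsplit ", minsplit_values)
n_leaves
```

Smaller minsplit values allow splits in smaller nodes and therefore tend to produce larger trees; the complexity parameter cp still prunes splits that do not improve the fit enough.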
We split the data into 5 folds, stratified by the class variable y. Stratification is preferable here because y is imbalanced. For more information on machine learning with imbalanced data, we refer you to the book Learning from Imbalanced Data Sets (2018).
```r
k_folds <- 5                                   # Number of folds in k-fold cross-validation
data_ex <- data.table(y, X, fold = 0)          # Add column for fold identifier
for (y_i in levels(y)) {                       # Create stratified folds, class by class:
  nrow_i <- nrow(data_ex[y == y_i, ])          # number of observations in class y_i
  n_per_fold_i <- ceiling(nrow_i / k_folds)    # fold labels needed per fold (rounded up)
  # randomly assign fold labels to the observations of class y_i
  data_ex[y == y_i, fold := sample(rep(1:k_folds, n_per_fold_i), nrow_i, replace = FALSE)]
}
data_ex[, table(y, fold)]                      # Number of observations per class and fold
#    fold
# y    1  2  3  4  5
#   1 14 14 15 15 13
#   2 35 35 35 35 35
#   3 51 51 50 51 51
```
As you can see, the class frequencies are roughly the same in all five folds. Try assigning observations to the folds without stratification and see how much more the class frequencies vary across folds!
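The exercise above can be sketched in base R. The two helpers below (assign_folds_random and assign_folds_stratified) and the simulated class vector are illustrative assumptions, not part of the original code. The stratified helper also differs slightly from the rep()/sample() approach above: within each class, it deals fold labels out round-robin from a random fold order, so a class's counts differ by at most one across folds (the table above shows 13 versus 15 for class 1).

```r
# Hypothetical helper: unstratified folds, ignoring the class variable
assign_folds_random <- function(n, k) {
  labels <- rep_len(seq_len(k), n)    # near-equal fold sizes overall
  labels[sample.int(n)]               # shuffle labels across all observations
}

# Hypothetical helper: stratified folds, assigned separately within each class
assign_folds_stratified <- function(y, k) {
  fold <- integer(length(y))
  for (lev in unique(as.character(y))) {
    idx <- which(as.character(y) == lev)
    # Round-robin labels starting from a random fold order, so each fold gets
    # either floor(n / k) or ceiling(n / k) observations of this class
    labels <- rep_len(sample.int(k), length(idx))
    fold[idx] <- labels[sample.int(length(idx))]   # shuffle within the class
  }
  fold
}

# Illustration with a simulated imbalanced class variable (not the tutorial's y)
set.seed(345)
y_sim <- factor(sample(1:3, 500, replace = TRUE, prob = c(0.14, 0.35, 0.51)))
table(y_sim, fold = assign_folds_random(length(y_sim), 5))    # class counts typically vary across folds
table(y_sim, fold = assign_folds_stratified(y_sim, 5))        # class counts differ by at most 1
```

Indexing with sample.int() instead of calling sample() on the labels avoids R's special case where sample(x) on a single number x draws from 1:x.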
Now we use the 5 folds to perform a cross-validation for choosing the minsplit value of the rpart algorithm. In this small example, we evaluate rpart under the different minsplit values by the overall accuracy. Exercise for you: there are other measures that are more suitable for imbalanced data. Change the code to use one of them. Do the results differ?
```r
result_cv <- sapply(1:k_folds, function(k_folds_i) {         # Apply k-fold cross-validation
  ind_train <- data_ex$fold != k_folds_i                      # Hold out fold k_folds_i, train on the rest
  result_cv_i <- sapply(choices_minsplit, function(option_i) { # For each candidate minsplit value:
    model <- rpart::rpart(y ~ .,                              # fit a classification tree
                          method = "class",
                          data = cbind(y, X)[ind_train, ],
                          control = rpart.control(minsplit = option_i))
    confMat <- caret::confusionMatrix(                        # confusion matrix on the held-out fold
      data = predict(model, X[!ind_train, ], type = "class"),
      reference = y[!ind_train],
      mode = "everything")
    confMat$overall["Accuracy"]                               # return the evaluation criterion
  })
  names(result_cv_i) <- paste0("minsplit ", choices_minsplit)
  result_cv_i
})
round(rowMeans(result_cv), digits = 2)                       # Average the results over the k folds
# minsplit 5 minsplit 10 minsplit 20 minsplit 40 minsplit 60
#       0.39       0.40       0.42       0.42       0.42
```
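In this small example, the accuracy of rpart differs only slightly across the minsplit values (0.39 to 0.42). Note that all of these values lie below the roughly 51% accuracy of always predicting the most frequent class, so accuracy results should be read against such a baseline.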
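For the exercise of replacing the overall accuracy, two commonly used alternatives for imbalanced data are balanced accuracy (the mean of the per-class recalls) and the macro-averaged F1 score. A base-R sketch follows; the function name imbalance_metrics is hypothetical and not part of caret or rpart.

```r
# Hypothetical helper: accuracy, balanced accuracy (mean per-class recall)
# and macro-averaged F1 from predicted and true class labels
imbalance_metrics <- function(predicted, reference) {
  lev <- union(levels(factor(reference)), levels(factor(predicted)))
  cm <- table(factor(reference, levels = lev),   # rows: true classes
              factor(predicted, levels = lev))   # columns: predicted classes
  tp <- diag(cm)                                 # correctly classified per class
  recall <- tp / rowSums(cm)                     # NaN for classes absent from reference
  precision <- tp / colSums(cm)
  precision[colSums(cm) == 0] <- 0               # never-predicted class: precision 0
  f1 <- ifelse(precision + recall > 0,
               2 * precision * recall / (precision + recall), 0)
  c(accuracy          = sum(tp) / sum(cm),
    balanced_accuracy = mean(recall, na.rm = TRUE),
    macro_f1          = mean(f1, na.rm = TRUE))
}

# Usage: always predicting the majority class scores poorly on both measures
imbalance_metrics(predicted = rep(3, 6), reference = c(1, 1, 1, 2, 2, 3))
```

In the cross-validation loop above, one could return, for example, imbalance_metrics(predict(model, X[!ind_train, ], type = "class"), y[!ind_train])["balanced_accuracy"] instead of confMat$overall["Accuracy"].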
Now it is up to you: check whether the accuracy (or your evaluation measure of choice) of rpart is sensitive to the other parameters of the algorithm (see ?rpart.control).
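As a starting point for this exercise, the sketch below wraps stratified k-fold cross-validation around a grid of rpart.control() settings. It assumes R's built-in iris data instead of the simulated data, uses plain accuracy as the criterion, and relies on hypothetical helper names (stratified_folds, cv_rpart_grid); swap in your own data and evaluation measure.

```r
library(rpart)

# Hypothetical helper: stratified fold labels (shuffled round-robin per class)
stratified_folds <- function(y, k) {
  fold <- integer(length(y))
  for (lev in unique(as.character(y))) {
    idx <- which(as.character(y) == lev)
    labels <- rep_len(sample.int(k), length(idx))
    fold[idx] <- labels[sample.int(length(idx))]
  }
  fold
}

# Hypothetical helper: mean held-out accuracy of rpart for each row of `grid`,
# whose column names must be rpart.control() arguments (e.g. minsplit, cp)
cv_rpart_grid <- function(data, response, grid, k = 5) {
  fold <- stratified_folds(data[[response]], k)
  form <- reformulate(".", response = response)   # e.g. Species ~ .
  acc <- sapply(seq_len(nrow(grid)), function(g) {
    ctrl <- do.call(rpart.control, as.list(grid[g, , drop = FALSE]))
    fold_acc <- sapply(seq_len(k), function(i) {
      train <- fold != i
      fit <- rpart(form, data = data[train, ], method = "class", control = ctrl)
      pred <- predict(fit, data[!train, ], type = "class")
      mean(pred == data[[response]][!train])       # accuracy on the held-out fold
    })
    mean(fold_acc)                                 # average over the k folds
  })
  cbind(grid, accuracy = acc)
}

set.seed(345)
grid <- expand.grid(minsplit = c(5, 20, 60), cp = c(0.001, 0.01, 0.05))
cv_rpart_grid(iris, "Species", grid, k = 5)
```

The folds are drawn once per call, so every parameter combination is evaluated on the same splits, which keeps the comparison between rows fair.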
Further Resources
You may also read the related posts on my homepage. I have already released several articles about the data.table package, which we used in this post:
- Join Multiple data.tables in R (6 Examples)
- Use lapply Function for data.table in R (4 Examples)
- Create Empty data.table with Column Names in R (2 Examples)
- Reshape data.table in R (3 Examples)
- R Programming Tutorials
Summary: In this article, you have learned how to create stratified folds for k-fold cross-validation in R. If you have any further questions, let me know in the comments below.
This page was created in collaboration with Anna-Lena Wölwer. Have a look at Anna-Lena’s author page to get further information about her academic background and the other articles she has written for Statistics Globe.