Stratified k-Fold Cross-Validation in R (Example)

 

In this R tutorial, you’ll learn how to draw cross-validation folds that are stratified by class. Stratified folds are especially useful when your variable of interest is imbalanced, that is, when the class frequencies differ greatly. Stratification ensures that the class frequencies in all folds are similar to the frequencies in the original data.

The tutorial looks as follows:

1) Exemplifying Data & Add-On Packages
2) Example: Stratified k-Fold Cross-Validation for Classification Tree
3) Video & Further Resources

Let’s get started.

 

Exemplifying Data & Add-On Packages

We use the data.table, rpart, and caret packages. Note that we have an extra overview blog post about the data.table package here.

install.packages("data.table")                                    # Install data.table package
library("data.table")                                             # Load data.table package
 
install.packages("rpart")                                         # Install rpart package
library("rpart")                                                  # Load rpart package
 
install.packages("caret")                                         # Install caret package
library("caret")                                                  # Load caret package

We create some example data according to a multinomial logistic regression model. That is: We first create an auxiliary data set X. Based on its model matrix, we apply a generalized linear model with a multinomial link to generate a class variable y with three classes.

For the formulas of the generalized linear model with multinomial link, we refer you to Agresti (2012) Categorical Data Analysis, especially Chapter 8.
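As a brief reminder of the model we simulate from (our own notation, not quoted from the book), a baseline-category logit model defines the class probabilities as

$$P(y_i = j \mid x_i) \;=\; \frac{\exp(x_i^\top \beta_j)}{\sum_{k=1}^{3} \exp(x_i^\top \beta_k)}, \qquad j = 1, 2, 3,$$

where the coefficient vector of the baseline class is fixed at zero. In the code below, this baseline corresponds to the column of zeros in cbind(0, eta_1, eta_2).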

N <- 500                                                           # Number of observations
set.seed(345)                                                      # Seed for reproducible results
X <- data.table(factor(sample(letters[1:20], N, replace = TRUE)),  # Create auxiliary information
                factor(sample(letters[1:20], N, replace = TRUE)),
                factor(sample(letters[1:20], N, replace = TRUE)),
                factor(sample(letters[1:20], N, replace = TRUE)))
X_mod_mat <- model.matrix(~., X)                                   # Create model matrix
beta_vec1 <- c(0, rep(.2, 76))                                     # Beta coefficients of the 2nd class (the 1st class is the baseline)
beta_vec2 <- c(0, rep(.3, 76))                                     # Beta coefficients of the 3rd class
eta_1     <- X_mod_mat %*% beta_vec1                               # Linear predictor of the 2nd class
eta_2     <- X_mod_mat %*% beta_vec2                               # Linear predictor of the 3rd class
probs     <- exp(cbind(0, eta_1, eta_2))                           # Exponentiate linear predictors (baseline class has linear predictor 0)
probs     <- probs / rowSums(probs)                                # Calculate probabilities
cum_probs <- t(apply(probs, 1, cumsum))                            # Cumulative probabilities

From the cumulative probabilities, we can draw actual classes.

y <- factor(apply(cum_probs, 1, 
                  function (x) { 
                    min(which(x > runif(1))) 
                  } ))                                             # Draw concrete classes
prop.table(table(y)) * 100                                         # Relative frequencies of generated classes
# y
#    1    2    3 
# 14.2 35.0 50.8

As you can see based on the previous output of the RStudio console, we created the variable y, which consists of three classes with quite unequal frequencies.

 

Example: Stratified k-Fold Cross-Validation for Classification Tree

In this section, I’ll demonstrate how to create stratified folds for cross-validation. For a general motivation of cross-validation, take a look at this post first. We also have a blog post about k-fold cross-validation here.

In the following small example, we want to tune one of the parameters of the rpart function from the rpart package. One of its options, which you can inspect via ?rpart.control, is minsplit: the minimum number of observations that must exist in a node for a split to be attempted. Its default value is 20. We use cross-validation to see how rpart performs under different values of minsplit. We consider the following candidate values.

choices_minsplit <- c(5, 10, 20, 40, 60)                           # Candidate minsplit values for the classification algorithm
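If you want to double-check the default value mentioned above, you can inspect the control object directly (just a quick side check, not required for the tutorial):

rpart::rpart.control()$minsplit                                    # Default value of minsplit
# [1] 20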

We define 5 splits of the data, stratified according to the imbalanced class frequencies. We prefer to have stratified folds as our variable of interest, y, is imbalanced. For more information on machine learning under imbalanced data, we refer you to the book Learning from Imbalanced Data Sets from 2018.

k_folds <- 5                                                       # Number of folds in k-fold cross-validation
 
data_ex <- data.table(y, X, fold = 0)                              # Add column for fold identifier
 
for (y_i in levels(y) ) {                                          # Create stratified folds
  nrow_i       <- nrow(data_ex[y == y_i,])
  n_per_fold_i <- ceiling(nrow_i / k_folds)
  data_ex[y == y_i, fold := sample(rep(1:k_folds, n_per_fold_i), nrow_i, replace = FALSE)]
}
 
data_ex[, table(y, fold)]                                          # Number of classes per fold
#    fold
# y    1  2  3  4  5
#   1 14 14 15 15 13
#   2 35 35 35 35 35
#   3 51 51 50 51 51

You can see that we achieved roughly the same class frequencies in all five folds. Try assigning the observations to the folds without stratification, as sketched below, to see how different the class frequencies can become!
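As a quick illustration of this exercise (a minimal sketch of ours; it reuses data_ex and k_folds from above), you could assign the folds completely at random and compare the resulting table:

data_unstrat <- copy(data_ex)                                      # Copy the data to keep the stratified folds intact
data_unstrat[, fold := sample(rep(1:k_folds, length.out = .N))]    # Assign folds without stratification
data_unstrat[, table(y, fold)]                                     # Class counts per fold now vary more strongly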

Now let us work through a small example: We take the 5 folds to perform a cross-validation for choosing the minsplit value of the rpart algorithm. In this small example, we use the overall accuracy to evaluate rpart under the different minsplit values; the confusion matrix is computed with the confusionMatrix function of the caret package. Exercise for you: There are other measures that are more suitable for imbalanced data. Change the code to one of the evaluation criteria you found (a small sketch for this exercise follows after the cross-validation results below). Do the results differ?

result_cv <- sapply(1:k_folds, # Apply k-fold cross-validation
                    function (k_folds_i) {
 
                      ind_train <- data_ex$fold != k_folds_i
 
                      result_cv_i <- sapply(choices_minsplit,     # For each candidate minsplit value: Fit a classification tree
                                            function (option_i) {
 
                                              model <- rpart::rpart(y ~ ., # Fit classification tree
                                                                    method = "class",
                                                                    data = cbind(y, X)[ind_train, ],
                                                                    control = rpart.control(minsplit  = option_i))
 
                                              confMat <- caret::confusionMatrix(data      = predict(model, X[!ind_train, ], type = "class"),     # Calculate confusion matrix
                                                                                reference = y[!ind_train],
                                                                                mode      = "everything")
 
                                              confMat$overall["Accuracy"] # Return evaluation criteria
 
                                            })
 
                      names(result_cv_i) <- paste0("minsplit ", choices_minsplit)
                      result_cv_i
                    })
 
round(rowMeans(result_cv), digits = 2) # Averaging the results over the k folds
# minsplit 5 minsplit 10 minsplit 20 minsplit 40 minsplit 60 
#       0.39        0.40        0.42        0.42        0.42

In this small example, there are only small differences in the accuracy of rpart with respect to different minsplit values.
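As a starting point for the exercise mentioned above (balanced accuracy and the macro-averaged F1 score are just possible choices of ours, not the only suitable measures), you could return one of the following values instead of confMat$overall["Accuracy"] inside the inner function:

mean(confMat$byClass[, "Balanced Accuracy"])                       # Macro-averaged balanced accuracy over the three classes
mean(confMat$byClass[, "F1"], na.rm = TRUE)                        # Macro-averaged F1 score (NA if a class is never predicted)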

Now it is up to you: Have a try and see whether the accuracy (or your evaluation measure of choice) of rpart is sensitive to the other parameters of the algorithm (available under ?rpart.control). A possible starting point is sketched below.
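One option (our own sketch, not part of the original code) is a small grid of candidate settings for minsplit and the complexity parameter cp, both documented under ?rpart.control. Each row of this grid could then be passed to rpart.control() inside the cross-validation shown above:

param_grid <- expand.grid(minsplit = c(5, 20, 60),                 # Candidate values for minsplit
                          cp       = c(0.001, 0.01, 0.05))         # Candidate values for the complexity parameter
param_grid                                                         # One candidate parameter setting per row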

 

Video & Further Resources

Have a look at the following video on my YouTube channel. I demonstrate the R programming code of this article in the video:

 

The YouTube video will be added soon.

 

Furthermore, you may read the related posts on my homepage. I have already released several articles related to the data.table package, which we used in this post.

 

Summary: In this article, you have learned how to create stratified folds for k-fold cross-validation in R programming. In case you have any further questions, let me know in the comments section below.

 

Anna-Lena Wölwer, Survey Statistician & R Programmer

This page was created in collaboration with Anna-Lena Wölwer. Have a look at Anna-Lena’s author page to get further information about her academic background and the other articles she has written for Statistics Globe.

 
