Cross-Validation Explained (Example)
Everyone who deals with machine learning methods comes across the term cross-validation at some point. In this blog post, we provide you with a brief introduction to cross-validation. In further blog posts, we focus on the concrete cross-validation techniques and their implementation in the R programming language and Python.
Let’s dive right into cross-validation!
Let’s assume your aim is to model some data to make predictions or classifications of new data points. A prominent example is that you receive emails, some of which are spam. Using your historic emails, all labelled as spam or not-spam, you can fit a model for classifying whether a new email is spam or not based on the text of the mail. For that, you can use various modelling techniques, e.g. a binary or multinomial logistic regression model, random forest, or Support Vector Machines.
For a classification or prediction model, you want to assess how well it actually performs for new data points. That is, you want to know how well it generalizes to unseen data. For that, you typically randomly split your available data into training and validation data. Based on the training data, you fit your model. Based on the validation data, you make predictions or classifications using your fitted model. By comparing the actual outcomes in the validation data (e.g. the labels indicating whether an email is spam or not) with the predictions or classifications of the model, you get an indication of how well the model actually performs for predicting/classifying new data points. To put it differently, based on the validation data you can calculate an estimator of how well a model generalizes for out-of-sample predictions/classifications.
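The split-fit-evaluate procedure described above can be sketched in a few lines of plain Python. This is only a minimal illustration under stated assumptions: the data are simulated, the single feature (a count of suspicious words) is hypothetical, and the "model" is a deliberately simple threshold rule rather than one of the techniques named earlier.

```python
import random

def train_validation_split(data, validation_fraction=0.3, seed=0):
    """Randomly split a list of records into training and validation parts."""
    rng = random.Random(seed)
    shuffled = data[:]                      # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    n_validation = round(len(shuffled) * validation_fraction)
    return shuffled[n_validation:], shuffled[:n_validation]

def fit_threshold(train):
    """Toy model: threshold halfway between the two class means of x."""
    spam = [x for x, label in train if label == 1]
    ham = [x for x, label in train if label == 0]
    return (sum(spam) / len(spam) + sum(ham) / len(ham)) / 2

def predict(threshold, x):
    """Classify as spam (1) if x lies above the fitted threshold."""
    return 1 if x > threshold else 0

# Simulated, hypothetical data: x = number of suspicious words, label 1 = spam.
rng = random.Random(42)
data = [(rng.gauss(8, 2), 1) for _ in range(100)]
data += [(rng.gauss(3, 2), 0) for _ in range(100)]

train, validation = train_validation_split(data)
threshold = fit_threshold(train)            # fit on the training data only

# Compare predictions with the actual labels in the validation data.
correct = sum(predict(threshold, x) == label for x, label in validation)
accuracy = correct / len(validation)
print(f"validation accuracy: {accuracy:.2f}")
```

The validation accuracy printed at the end is the kind of out-of-sample performance estimate the paragraph above refers to.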
You can use this technique of splitting your data into training and validation data to compare competing models, choose your model parameters, and perform model selection. For example, for a random forest algorithm you can use model validation to decide how many trees to combine, how deep these trees are allowed to grow, and how many features to randomly consider at each split. Note that the part of the data used for validation is also often referred to as the hold-out set.
Often, the data is divided not into two but into three subsets: training data, validation data, and test data. You can use the validation data to tune your model, for example by selecting the parameters of a random forest algorithm, or, more generally, to select a model among candidates. Subsequently, you can use the test set, which was used neither for fitting the model nor for comparing different parameter sets, to get an idea of how well the final model predicts or classifies new data points.
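As a rough sketch of the three-way split, the following Python example tunes one parameter on the validation data and then evaluates the chosen model once on the test data. The assumptions are ours: the data are simulated, a simple k-nearest-neighbour classifier stands in for the random forest (its parameter k plays the role of the tuning parameter), and the 60/20/20 split and the candidate values are arbitrary choices for illustration.

```python
import random

# Simulated, hypothetical data: (feature value, label), label 1 = spam.
rng = random.Random(1)
data = [(rng.gauss(8, 2), 1) for _ in range(150)]
data += [(rng.gauss(3, 2), 0) for _ in range(150)]
rng.shuffle(data)

# Split into 60% training, 20% validation, 20% test data.
n = len(data)
n_train, n_validation = n * 6 // 10, n * 2 // 10
train = data[:n_train]
validation = data[n_train:n_train + n_validation]
test = data[n_train + n_validation:]

def knn_predict(train, x, k):
    """Majority vote among the k training points closest to x."""
    neighbours = sorted(train, key=lambda record: abs(record[0] - x))[:k]
    votes = sum(label for _, label in neighbours)
    return 1 if 2 * votes > k else 0

def accuracy(train, records, k):
    """Share of records whose label is predicted correctly."""
    hits = sum(knn_predict(train, x, k) == label for x, label in records)
    return hits / len(records)

# Tune k on the validation data (odd values avoid tied votes).
candidates = [1, 5, 15]
best_k = max(candidates, key=lambda k: accuracy(train, validation, k))

# The test data is touched only once, for the final model.
test_accuracy = accuracy(train, test, best_k)
print(f"chosen k: {best_k}, test accuracy: {test_accuracy:.2f}")
```

Because the test data played no part in choosing k, the test accuracy is not affected by the selection step.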
For a more profound understanding of the theoretical background of model validation and the estimation of the generalization error of a model, we recommend taking a look at:
- Principles and Theory for Data Mining and Machine Learning by Bertrand Clarke, Ernest Fokoue, and Hao Helen Zhang (especially Chapter 1.3.2)
- Understanding Machine Learning: From Theory to Algorithms by Shai Shalev-Shwartz and Shai Ben-David (especially Chapter 11, which explains how validation estimates the true risk of an algorithm)
- The Elements of Statistical Learning by Trevor Hastie, Robert Tibshirani, and Jerome Friedman (especially Chapter 7)
So far, we have only stated very generally that we can use the validation data to validate a model. But how do we actually do that? We choose a loss function and calculate its value based on the validation data. We plug the actually observed labels/values in the validation data and the corresponding model predictions into the loss function, which measures the chosen distance between them.
The choice of a loss function depends on the data at hand. A prominent loss function is the squared error loss (=L2 loss), see e.g., The Elements of Statistical Learning, Chapter 2.4. The squared error loss sums up the squared differences between the actual values and their predictions. Another prominent loss function is the L1 loss, which sums up the absolute differences between values and predictions.
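To make the two loss functions concrete, here is a small Python sketch that computes both for a handful of made-up values. The function names and the numbers are our own illustrative choices.

```python
def squared_error_loss(actual, predicted):
    """L2 loss: sum of squared differences between values and predictions."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted))

def absolute_error_loss(actual, predicted):
    """L1 loss: sum of absolute differences between values and predictions."""
    return sum(abs(a - p) for a, p in zip(actual, predicted))

# Made-up validation values and model predictions.
actual = [3.0, 5.0, 2.0, 7.0]
predicted = [2.0, 5.0, 4.0, 4.0]

# The differences are 1, 0, -2, and 3.
l2 = squared_error_loss(actual, predicted)   # 1 + 0 + 4 + 9 = 14.0
l1 = absolute_error_loss(actual, predicted)  # 1 + 0 + 2 + 3 = 6.0
print(l2, l1)
```

The largest error (3) contributes 9 of the 14 units of squared error loss but only 3 of the 6 units of L1 loss, which shows that the squared error loss penalizes large deviations more heavily. Dividing either sum by the number of data points gives the corresponding mean loss, which is easier to compare across validation sets of different sizes.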
A word of caution: As illustrated in detail in The Elements of Statistical Learning, Chapter 7.10.2, one should be careful not to carry out analyses on the full data, like selecting auxiliary variables for a prediction model, and only afterwards split the data into training and validation data. Information previously acquired from the full dataset leaks into the validation step and can lead to false conclusions, like biased (typically overly optimistic) estimates of the generalization error of a model.
In an ideal world, our data at hand is large enough that we can nicely make a random split into training and validation or training, validation, and test data. However, in real applications, we often do not have many data points available. We know that a model typically gets more accurate when it is fed with more data. For example, see this post about sample sizes for machine learning algorithms. Therefore, we have an incentive to use as much of the data as possible to train the model(s). On the other hand, our estimates of model performance are also more accurate (smaller variance) the more data we use for validation.
Cross-validation addresses this trade-off between using as much data as possible for training and using as much data as possible for validation. In cross-validation, we repeat the process of randomly splitting the data into training and validation data several times and choose a measure to combine the results of the different splits. Note that cross-validation is typically only applied to the training and validation data, and the model testing is still done on a separate test set.
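The idea of repeating the random split and combining the results can be sketched as follows in Python. This is one simple variant (repeated hold-out with averaging) under our own assumptions: the data are simulated, the toy model always predicts the training mean, the loss is the mean squared error, and the number of repetitions and the validation fraction are arbitrary.

```python
import random
from statistics import mean

def repeated_holdout(values, n_repeats=20, validation_fraction=0.25, seed=0):
    """Average the validation loss over several random train/validation splits."""
    rng = random.Random(seed)
    n_validation = max(1, int(len(values) * validation_fraction))
    losses = []
    for _ in range(n_repeats):
        shuffled = values[:]
        rng.shuffle(shuffled)               # a fresh random split each time
        validation = shuffled[:n_validation]
        train = shuffled[n_validation:]
        prediction = mean(train)            # toy model: predict the training mean
        losses.append(mean((y - prediction) ** 2 for y in validation))
    return mean(losses), losses             # combined estimate and per-split losses

# Simulated, hypothetical outcome values.
rng = random.Random(7)
values = [rng.gauss(10, 2) for _ in range(40)]

average_loss, losses = repeated_holdout(values)
print(f"loss per split ranges from {min(losses):.2f} to {max(losses):.2f}")
print(f"combined estimate: {average_loss:.2f}")
```

The spread of the per-split losses shows how much a single random split can vary, which is what averaging over several splits is meant to smooth out.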
Types of Cross-Validation
Prominent types of cross-validation are:
- K-fold cross-validation
- Leave-one-out cross-validation
We will cover both procedures in separate blog posts.
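As a brief preview, the index bookkeeping behind both procedures can be sketched in plain Python. The helper k_fold_indices is a hypothetical function written for this illustration, not part of any library: it partitions the data into k folds, each of which serves once as validation data, and leave-one-out cross-validation is the special case in which k equals the number of data points.

```python
import random

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_indices, validation_indices) for each of the k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]   # k nearly equal-sized folds
    for i, validation in enumerate(folds):
        # All folds except the i-th one form the training data.
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, validation

# 5-fold cross-validation on 10 data points.
five_fold = list(k_fold_indices(n_samples=10, k=5))
for train, validation in five_fold:
    print(sorted(validation), "held out;", len(train), "used for training")

# Leave-one-out cross-validation: as many folds as data points.
leave_one_out = list(k_fold_indices(n_samples=10, k=10))
print(len(leave_one_out), "splits, each holding out a single data point")
```

In practice, one would fit the model on each set of training indices, compute the loss on the corresponding validation indices, and combine the k results, for example by averaging.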
Video & Summary
In case you want to learn more about cross-validation, you may watch the following video from the StatQuest with Josh Starmer YouTube channel. In the video, the speaker explains the concept of cross-validation based on an easy-to-follow example.
Summary: With this article, we gave a brief introduction to the general idea of cross-validation. If you have additional questions, please let us know in the comments below. Furthermore, please subscribe to the email newsletter to get updates on new tutorials.
This page was created in collaboration with Anna-Lena Wölwer. Have a look at Anna-Lena’s author page to get further details about her academic background and the other articles she has written for Statistics Globe.