How to split a dataset

In machine learning, you need to split your dataset into two parts:

  • a training set that you can use to train your model and find optimal parameters
  • a test set that you can use to test your trained model and see how well it generalises.

It is important that the test data is never used during the training phase. Using “unseen” data is what allows us to test how well our model generalises. It doesn’t prevent overfitting by itself, but it is how you find out whether your model has overfitted the training data.

However, if your model has hyperparameters, you need a validation set as well. The validation set is used to compare candidate hyperparameter values. You don’t want to use the test set to tune the hyperparameters, because you want to keep that data unseen to measure how well your model generalises.

People usually start with an 80/20 split (80% training set, 20% test set) and then split the training set once more in an 80/20 ratio to create the validation set. That leaves 64% of the data for training, 16% for validation and 20% for testing. It’s a good start, but it’s more a rule of thumb than anything else, and you may want to adjust the splits depending on the amount of available data.
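To make the two-stage split concrete, here is a minimal sketch using only the Python standard library. The function name `train_val_test_split` and its parameters are made up for this illustration, and the data is just the integers 0 to 999 standing in for real samples.

```python
import random


def train_val_test_split(samples, test_fraction=0.2, val_fraction=0.2, seed=0):
    """Shuffle, carve off a test set, then carve a validation set from the rest.

    Hypothetical helper for illustration only; the fractions default to the
    80/20 rule of thumb applied twice.
    """
    shuffled = list(samples)
    # Shuffle with a fixed seed so the split is reproducible.
    random.Random(seed).shuffle(shuffled)

    # First split: hold out the test set.
    n_test = round(len(shuffled) * test_fraction)
    test = shuffled[:n_test]
    remaining = shuffled[n_test:]

    # Second split: hold out the validation set from what is left.
    n_val = round(len(remaining) * val_fraction)
    val = remaining[:n_val]
    train = remaining[n_val:]
    return train, val, test


train, val, test = train_val_test_split(range(1000))
print(len(train), len(val), len(test))  # 640 160 200
```

Shuffling before splitting matters when the data is stored in some meaningful order (by date or by class, for example). For time series you would usually not shuffle at all, which is outside the scope of this sketch.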

The key principle to understand is that the more samples you have, the lower the variance of whatever you estimate from them. So you need the training set to be big enough to achieve low variance in the fitted model parameters.
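A small simulation can illustrate this. The sketch below assumes a deliberately trivial "model" whose only parameter is the mean of some noisy observations. The Gaussian data, the true mean of 5.0 and the name `parameter_spread` are all assumptions made for the example. It refits the parameter on many independent training sets of a given size and reports how much the fitted value moves around.

```python
import random
import statistics


def parameter_spread(n_train, n_repeats=200, seed=0):
    """Standard deviation of a fitted parameter across repeated training sets.

    The "model" has a single parameter, the sample mean, fitted on synthetic
    Gaussian data (true mean 5.0, standard deviation 2.0: illustrative values).
    """
    rng = random.Random(seed)
    fitted = []
    for _ in range(n_repeats):
        # Draw a fresh training set of the requested size and "train" on it.
        training_set = [rng.gauss(5.0, 2.0) for _ in range(n_train)]
        fitted.append(statistics.mean(training_set))
    # How much the fitted parameter varies from one training set to the next.
    return statistics.stdev(fitted)


spreads = {n: parameter_spread(n) for n in (10, 100, 1000)}
for n, spread in spreads.items():
    print(f"n_train={n:5d}  spread of fitted parameter: {spread:.3f}")
```

For this toy model the spread shrinks roughly like 1/sqrt(n), so each tenfold increase in training data reduces it by a factor of about three. Real models behave less neatly, but the trend is the same.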

Similarly, the test set needs to be big enough that the performance you measure on it has low variance.
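For a rough sense of scale, assume the performance metric is accuracy and the test samples are independent. The measured accuracy then behaves like a binomial proportion, with a standard error of sqrt(p(1 - p) / n). The function name and the 90% accuracy figure below are only illustrative.

```python
import math


def accuracy_standard_error(accuracy, n_test):
    """Standard error of an accuracy measured on n_test independent samples.

    Uses the binomial approximation sqrt(p * (1 - p) / n).
    """
    return math.sqrt(accuracy * (1.0 - accuracy) / n_test)


# How uncertain is a measured accuracy of 90% for different test set sizes?
errors = {n: accuracy_standard_error(0.9, n) for n in (100, 1000, 10000)}
for n, err in errors.items():
    print(f"n_test={n:6d}  standard error: {err:.4f}")
```

With 100 test samples the standard error is about 3 percentage points, so a difference between 90% and 92% accuracy is not meaningful. With 10,000 samples it drops to about 0.3 points.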

The idea is to split the data to achieve low variance in both cases. If your dataset is big enough to achieve low variance on the trained parameters, increasing the training set any further won’t help much but will increase the training time.

On the other hand, if your dataset is too small to achieve low variance when split into training/validation/test sets, you may use a cross-validation technique such as “k-fold”.

With “k-fold”, the dataset is split into k mutually exclusive subsets (they don’t overlap). Then you train your model k times, each time using the ith subset as the test dataset and the remainder of the dataset as the training (and validation) set. The k test scores are then typically averaged to get the final performance estimate.
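The index bookkeeping behind k-fold can be sketched in a few lines of plain Python. The generator `k_fold_indices` is a hypothetical helper written for this post, not a library function. It only produces the index splits, and training and scoring a model on each split is left to you.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_indices, test_indices) pairs for k-fold cross-validation.

    Hypothetical helper for illustration: shuffles the indices once, cuts them
    into k non-overlapping folds, and uses each fold as the test set in turn.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)

    # Fold i takes every k-th shuffled index starting at position i,
    # so fold sizes differ by at most one.
    folds = [indices[i::k] for i in range(k)]

    for i in range(k):
        test_idx = folds[i]
        # Everything that is not in fold i is used for training.
        train_idx = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train_idx, test_idx


splits = list(k_fold_indices(10, k=5))
for fold_number, (train_idx, test_idx) in enumerate(splits):
    print(f"fold {fold_number}: test={sorted(test_idx)}  train size={len(train_idx)}")
```

Every sample ends up in the test set exactly once, which is what makes k-fold attractive for small datasets: no data is permanently set aside, at the cost of training the model k times.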