Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


machine learning - How to use k-fold cross validation in a neural network

We are writing a small ANN which is supposed to categorize 7000 products into 7 classes based on 10 input variables.

In order to do this we have to use k-fold cross validation but we are kind of confused.

We have this excerpt from the presentation slide:

[Image: k-fold cross validation diagram]

What exactly are the validation and test sets?

From what we understand, we run through the 3 training sets and adjust the weights (single epoch). Then what do we do with the validation set? As far as I understand, the test set is used to get the error of the network.

What happens next is also confusing to me. When does the crossover take place?

If it's not too much to ask, a bullet list of steps would be appreciated.

question from:https://stackoverflow.com/questions/25889637/how-to-use-k-fold-cross-validation-in-a-neural-network


1 Reply


You seem to be a bit confused (I remember I was too) so I am going to simplify things for you. ;)

Sample Neural Network Scenario

Whenever you are given a task such as devising a neural network, you are often also given a sample dataset to use for training purposes. Let us assume you are training a simple neural network system Y = W · X, where Y is the output computed as the scalar product (·) of the weight vector W with a given sample vector X. Now, the naive way to go about this would be to use the entire dataset of, say, 1000 samples to train the neural network. Assuming that the training converges and your weights stabilise, you can then safely say that your network will correctly classify the training data. But what happens to the network if it is presented with previously unseen data? Clearly the purpose of such systems is to be able to generalise and correctly classify data other than the data used for training.
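As a minimal sketch of the Y = W · X model above, the snippet below computes the scalar product in plain Python. The weights, the sample and the thresholding step in predict() are assumptions made for illustration; the text only defines Y as the scalar product.

```python
def dot(w, x):
    """Scalar (dot) product of two equal-length vectors."""
    if len(w) != len(x):
        raise ValueError("w and x must have the same length")
    return sum(wi * xi for wi, xi in zip(w, x))


def predict(w, x, threshold=0.0):
    """Assumed decision rule: class 1 if W . X exceeds the threshold, else class 0."""
    return 1 if dot(w, x) > threshold else 0


# Hypothetical weights and sample, chosen only for illustration.
W = [0.5, -1.0, 2.0]
X = [1.0, 2.0, 0.5]
Y = dot(W, X)            # 0.5*1.0 + (-1.0)*2.0 + 2.0*0.5 = -0.5
label = predict(W, X)    # -0.5 is not above the threshold, so class 0
```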

Overfitting Explained

In any real-world situation, however, previously unseen data only becomes available once your neural network is deployed in a, let's call it, production environment. But since you have not tested it adequately, you are probably going to have a bad time. :) The phenomenon by which a learning system matches its training set almost perfectly but consistently fails on unseen data is called overfitting.

The Three Sets

This is where the validation and testing parts of the algorithm come in. Let's go back to the original dataset of 1000 samples. What you do is split it into three sets -- training, validation and testing (Tr, Va and Te) -- using carefully selected proportions. An 80-10-10 split is a common starting point, where:

  • Tr = 80%
  • Va = 10%
  • Te = 10%
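The split above could be sketched as follows. This is only an illustration: split_dataset and its parameters are hypothetical names, and a list of integers stands in for the 1000 samples.

```python
import random


def split_dataset(samples, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle a copy of `samples` and cut it into (Tr, Va, Te)."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_train = int(round(train_frac * len(shuffled)))
    n_val = int(round(val_frac * len(shuffled)))
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains goes to Te
    return train, val, test


data = list(range(1000))  # stand-in for the 1000 samples
Tr, Va, Te = split_dataset(data)
print(len(Tr), len(Va), len(Te))  # 800 100 100
```

Shuffling before cutting matters when the dataset is stored in some order (for example, sorted by class), since otherwise one set could end up missing whole classes.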

Training and Validation

Now what happens is that the neural network is trained on the Tr set and its weights are updated. The validation set Va is then used to compute the classification error E = M - Y using the weights resulting from the training, where M is the expected output vector taken from the validation set and Y is the computed output resulting from the classification (Y = W · X). If the error is higher than a user-defined threshold, the whole training-validation epoch is repeated. The training phase ends when the error computed on the validation set is deemed low enough.
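A rough sketch of this training-validation loop is given below. Several things here are assumptions rather than part of the description above: the weights are updated with the classic perceptron rule, the per-sample error E = M - Y is summarised as the fraction of misclassified validation samples, a max_epochs cap guards against an endless loop, and the tiny dataset is made up.

```python
def dot(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))


def predict(w, x):
    # Assumed decision rule: class 1 if W . X is positive, else class 0.
    return 1 if dot(w, x) > 0 else 0


def error_rate(w, samples):
    """Fraction of (x, m) pairs whose output Y differs from the expected M."""
    wrong = sum(1 for x, m in samples if predict(w, x) != m)
    return wrong / len(samples)


def train_one_epoch(w, samples, lr=0.1):
    """One pass of the perceptron rule: w += lr * (M - Y) * x."""
    for x, m in samples:
        e = m - predict(w, x)
        w = [wi + lr * e * xi for wi, xi in zip(w, x)]
    return w


def train_until_valid(train, val, threshold=0.05, max_epochs=100):
    """Repeat training epochs until the validation error is low enough."""
    w = [0.0] * len(train[0][0])
    for epoch in range(1, max_epochs + 1):
        w = train_one_epoch(w, train)
        val_error = error_rate(w, val)
        if val_error <= threshold:
            break
    return w, epoch, val_error


# Made-up samples (bias, a, b) labelled 1 when a > b.
Tr = [([1.0, 2.0, 0.0], 1), ([1.0, 0.0, 2.0], 0),
      ([1.0, 3.0, 1.0], 1), ([1.0, 1.0, 3.0], 0)]
Va = [([1.0, 4.0, 1.0], 1), ([1.0, 0.5, 2.5], 0)]
W, epochs_used, val_error = train_until_valid(Tr, Va)
```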

Smart Training

Now, a smart ruse here is to randomly select which samples to use for training and validation from the total set Tr + Va at each epoch iteration. This makes it harder for the network to over-fit any one fixed training set.
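One way to picture this re-selection is the sketch below, which re-partitions a pool of 900 samples (the Tr + Va part of the 1000-sample example) at every epoch. The helper name and the fixed seed are assumptions; the Te samples are deliberately left out of the pool.

```python
import random


def resample_train_val(pool, val_frac, rng):
    """Randomly re-partition the Tr+Va pool into fresh (Tr, Va) sets."""
    shuffled = list(pool)
    rng.shuffle(shuffled)
    n_val = max(1, int(round(val_frac * len(shuffled))))
    return shuffled[n_val:], shuffled[:n_val]


rng = random.Random(42)
pool = list(range(900))  # stand-in for the 900 samples in Tr + Va
for epoch in range(3):
    Tr, Va = resample_train_val(pool, val_frac=1 / 9, rng=rng)
    # ... train on Tr, then compute the validation error on Va ...
```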

Testing

The testing set Te is then used to measure the performance of the network. This data is perfect for this purpose as it was never used throughout the training and validation phase. It is effectively a small set of previously unseen data, which is supposed to mimic what would happen once the network is deployed in the production environment.

The performance is again measured in terms of classification error as explained above. The performance can also (or maybe even should) be measured in terms of precision and recall so as to know where and how the error occurs, but that's a topic for another Q&A.
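For readers who want a quick feel for those two measures, here is a small sketch for a single class of interest. The label vectors are invented for illustration, and returning 0.0 when a denominator is zero is just one possible convention.

```python
def precision_recall(expected, predicted, positive):
    """Precision and recall for one class, from expected (M) and computed (Y) labels."""
    pairs = list(zip(expected, predicted))
    tp = sum(1 for m, y in pairs if y == positive and m == positive)
    fp = sum(1 for m, y in pairs if y == positive and m != positive)
    fn = sum(1 for m, y in pairs if y != positive and m == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Invented test-set labels: M is expected, Y is what the network produced.
M = [1, 1, 1, 0, 0, 0, 0, 1]
Y = [1, 0, 1, 0, 1, 0, 0, 1]
precision, recall = precision_recall(M, Y, positive=1)
error = sum(1 for m, y in zip(M, Y) if m != y) / len(M)
print(precision, recall, error)  # 0.75 0.75 0.25
```

Precision tells you how many of the samples assigned to a class really belong to it, while recall tells you how many of the samples belonging to the class were actually found. With 7 classes you would compute the pair once per class.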

Cross-Validation

Having understood this training-validation-testing mechanism, one can further strengthen the network against over-fitting by performing K-fold cross-validation. This is somewhat of an evolution of the smart ruse I explained above. The technique involves performing K rounds of training-validation-testing, each on a different split of the data into Tr, Va and Te sets of the same proportions, chosen so that the Te sets of the different rounds do not overlap.

Given k = 10, in each of the k rounds you will split your dataset into Tr+Va = 90% and Te = 10%, using a different 10% for Te each time, and you will run the algorithm, recording the testing performance.

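k = 10
# Shuffle once, then cut the data into k non-overlapping folds of equal size
Folds <-- split(shuffle(Data), k)
for i in 1:k
     # Fold i is the testing set; the remaining k-1 folds form Tr+Va
     KFoldTesting <-- Folds[i]
     KFoldTraining <-- all folds except Folds[i]

     # Train (with validation) and record the testing performance
     KFoldPerformance[i] <-- SmartTrain(KFoldTraining, KFoldTesting)

# Compute overall performance, e.g. the mean over the k rounds
TotalPerformance <-- ComputePerformance(KFoldPerformance)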
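The pseudocode above could look roughly like this in Python. It is a sketch under stated assumptions: k_fold_indices and cross_validate are hypothetical helper names, train_and_test stands for whatever routine trains on Tr+Va and returns a test score (the role SmartTrain plays above), and the overall performance is taken to be the mean of the per-fold scores.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_val_idx, test_idx) for each of the k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]  # k disjoint folds
    for i in range(k):
        test_idx = folds[i]
        train_val_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_val_idx, test_idx


def cross_validate(data, k, train_and_test):
    """Call train_and_test(train_val, test) once per fold and average the scores."""
    scores = []
    for train_val_idx, test_idx in k_fold_indices(len(data), k):
        train_val = [data[j] for j in train_val_idx]
        test = [data[j] for j in test_idx]
        scores.append(train_and_test(train_val, test))
    return sum(scores) / len(scores), scores


# Stand-in scoring function; a real one would train a network and test it.
def fake_train_and_test(train_val, test):
    return len(test) / (len(train_val) + len(test))


data = list(range(1000))
mean_score, fold_scores = cross_validate(data, 10, fake_train_and_test)
```

Every sample ends up in the testing set exactly once, which is why the averaged score is a less lucky (or unlucky) estimate than a single 90/10 split.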

Overfitting Shown

I am taking the world-famous plot below from Wikipedia to show how the validation set helps prevent overfitting. The training error, in blue, tends to decrease as the number of epochs increases: the network is therefore attempting to match the training set exactly. The validation error, in red, on the other hand follows a different, U-shaped profile. The minimum of that curve is where the training should ideally be stopped, as this is the point at which the validation error is lowest. Beyond it the training error keeps falling while the validation error rises, which means the network has started to overfit.

[Image: Overfitting reduced by validating neural network]
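To make the "stop at the minimum of the validation curve" idea concrete, here is a small sketch. The error values are made up to mimic the shape of the two curves, and the patience-based rule is a common practical variant that is assumed here, not something taken from the plot.

```python
def best_epoch(val_errors):
    """Return (epoch, error) at the minimum of the validation curve (epochs start at 1)."""
    idx = min(range(len(val_errors)), key=val_errors.__getitem__)
    return idx + 1, val_errors[idx]


def early_stop_epoch(val_errors, patience=2):
    """Epoch at which training halts: the validation error has failed to
    improve on its best value for `patience` epochs in a row."""
    best = float("inf")
    waited = 0
    for epoch, err in enumerate(val_errors, start=1):
        if err < best:
            best, waited = err, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return len(val_errors)


# Made-up curves: training error keeps falling, validation error is U-shaped.
train_errors = [0.40, 0.30, 0.22, 0.16, 0.11, 0.07, 0.04]
val_errors = [0.42, 0.33, 0.27, 0.25, 0.26, 0.29, 0.33]

print(best_epoch(val_errors))        # (4, 0.25)
print(early_stop_epoch(val_errors))  # 6
```

In practice you only know the minimum in hindsight, so a typical approach is to keep a copy of the weights from the best epoch so far and restore them once training halts.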

References

For more references this excellent book will give you both a sound knowledge of machine learning as well as several migraines. Up to you to decide if it's worth it. :)

