Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

machine learning - Early stopping with Keras and sklearn GridSearchCV cross-validation

I wish to implement early stopping with Keras and sklearn's GridSearchCV.

The working code example below is modified from the article "How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras". The data set may be downloaded from here.

The modification adds the Keras EarlyStopping callback to prevent over-fitting. For this to be effective, the callback needs the monitor='val_acc' argument so that it monitors validation accuracy. For val_acc to be available, KerasClassifier needs validation_split=0.1, which makes Keras hold out part of the training data and report validation accuracy on it; otherwise EarlyStopping issues RuntimeWarning: Early stopping requires val_acc available!. Note the FIXME code comment!
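For reference, the data that validation_split=0.1 sets aside can be sketched in pure Python. Keras documents that the validation data is taken from the last samples of the arrays passed to fit, before shuffling; the helper validation_split_indices below is hypothetical and only mimics that slicing. Inside GridSearchCV, the arrays passed to fit are the training part of each CV split, so it is that training fold which loses its last 10%.

```python
def validation_split_indices(n_samples, validation_split):
    """Return (train_idx, val_idx) as Keras' validation_split would slice them.

    Assumption: the *last* fraction of the samples, taken before any
    shuffling, becomes the validation set (as the Keras fit docs describe).
    """
    split_at = int(n_samples * (1.0 - validation_split))
    train_idx = list(range(split_at))           # samples actually trained on
    val_idx = list(range(split_at, n_samples))  # samples only used for val_acc / val_loss
    return train_idx, val_idx


# e.g. a CV training fold of 614 rows (roughly 4/5 of the 768-row Pima data)
train_idx, val_idx = validation_split_indices(614, 0.1)
print(len(train_idx), len(val_idx))  # → 552 62
```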

Note that val_acc could be replaced by val_loss!
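The patience rule that EarlyStopping applies to the monitored quantity can be sketched as follows. This is a simplified, hypothetical stopping_epoch helper that assumes min_delta=0, no baseline and no weight restoring; the only difference between monitoring val_acc and val_loss is the direction of improvement, which Keras infers from the metric name when mode='auto'.

```python
def stopping_epoch(history, patience=3, mode="max"):
    """Return the 0-based epoch at which training would stop, or None.

    mode="max" suits accuracy-like metrics (val_acc); mode="min" suits val_loss.
    Training stops once `patience` consecutive epochs bring no improvement.
    """
    if mode == "max":
        improved = lambda new, best: new > best
    else:
        improved = lambda new, best: new < best
    best = None
    wait = 0  # epochs since the last improvement
    for epoch, value in enumerate(history):
        if best is None or improved(value, best):
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # ran to the maximum number of epochs


val_acc = [0.60, 0.65, 0.70, 0.69, 0.70, 0.68, 0.71]
val_loss = [0.70, 0.62, 0.58, 0.59, 0.60, 0.58, 0.55]
print(stopping_epoch(val_acc, patience=3, mode="max"))   # → 5
print(stopping_epoch(val_loss, patience=3, mode="min"))  # → 5
```

In both made-up histories a later epoch would have improved again, which is the trade-off that the patience setting controls.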

Question: How can I use the validation folds generated by GridSearchCV's k-fold cross-validation instead of wasting 10% of the training data on a separate early-stopping validation set?

# Use scikit-learn to grid search the learning rate and momentum
import numpy
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.optimizers import SGD

# Function to create model, required for KerasClassifier
def create_model(learn_rate=0.01, momentum=0):
    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    # Compile model
    optimizer = SGD(lr=learn_rate, momentum=momentum)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

# Early stopping
from keras.callbacks import EarlyStopping
# 'val_acc' is the metric name in older Keras; newer Keras/tf.keras log it as 'val_accuracy'
stopper = EarlyStopping(monitor='val_acc', patience=3, verbose=1)

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = KerasClassifier(
    build_fn=create_model,
    epochs=100, batch_size=10,
    validation_split=0.1, # FIXME: Instead use GridSearchCV k-fold validation data.
    verbose=2)
# define the grid search parameters
learn_rate = [0.01, 0.1]
momentum = [0.2, 0.4]
param_grid = dict(learn_rate=learn_rate, momentum=momentum)
grid = GridSearchCV(estimator=model, param_grid=param_grid, verbose=2, n_jobs=1)

# Fitting parameters
fit_params = dict(callbacks=[stopper])
# Grid search.
grid_result = grid.fit(X, Y, **fit_params)

# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))
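Outside GridSearchCV, what the question asks for could be sketched as a hand-written cross-validation loop in which each split's held-out fold is passed to Keras as validation_data=(X[val_idx], Y[val_idx]) instead of using validation_split. The pure-Python sketch below only shows the unshuffled k-fold indexing such a loop would iterate over, mirroring the fold sizes of scikit-learn's KFold(shuffle=False); kfold_indices is a hypothetical helper, not a library function.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train_idx, val_idx) pairs for an unshuffled k-fold split.

    As in scikit-learn's KFold(shuffle=False), the first
    n_samples % n_splits folds get one extra sample.
    """
    fold_sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = list(range(start, stop))  # held-out fold
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


for train_idx, val_idx in kfold_indices(10, 3):
    print(len(train_idx), val_idx)
# → 6 [0, 1, 2, 3]
# → 7 [4, 5, 6]
# → 7 [7, 8, 9]
```

Note that in such a loop the same fold would both decide when training stops and provide the score being compared across hyperparameter sets.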

[Answer after the question was edited & clarified:]

Before rushing into implementation issues, it is always good practice to take some time to think about the methodology and the task itself; arguably, intermingling early stopping with the cross-validation procedure is not a good idea.

Let's make up an example to highlight the argument.

Suppose that you do use early stopping with a maximum of 100 epochs, together with 5-fold cross-validation (CV) for hyperparameter selection. Suppose also that you end up with a hyperparameter set X giving the best performance, say 89.3% binary classification accuracy.

Now suppose that your second-best hyperparameter set, Y, gives 89.2% accuracy. Examining the individual CV folds closely, you see that, for your best case X, 3 out of the 5 CV folds exhausted the maximum of 100 epochs, while in the other 2 early stopping kicked in, say at 95 and 93 epochs respectively.

Now imagine that, examining your second-best set Y, you see that again 3 out of the 5 CV folds exhausted the 100 epochs, while the other 2 both stopped considerably earlier, at around 80 epochs.
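Tabulating the made-up numbers of this example makes the imbalance explicit. The sketch below only reuses figures from the text and assumes that "around 80" means exactly 80.

```python
# Per-fold stopping epochs from the example (100 = ran to the maximum).
epochs_per_fold = {
    "X": [100, 100, 100, 95, 93],
    "Y": [100, 100, 100, 80, 80],  # "around 80" taken as exactly 80
}
mean_cv_accuracy = {"X": 89.3, "Y": 89.2}  # percent, as in the example

summary = {}
for name, epochs in epochs_per_fold.items():
    mean_epochs = sum(epochs) / len(epochs)
    summary[name] = (mean_cv_accuracy[name], mean_epochs, min(epochs))
    print(f"{name}: accuracy {mean_cv_accuracy[name]}%, "
          f"mean epochs {mean_epochs}, fewest epochs {min(epochs)}")
# → X: accuracy 89.3%, mean epochs 97.6, fewest epochs 93
# → Y: accuracy 89.2%, mean epochs 92.0, fewest epochs 80
```

A 0.1-point accuracy gap sits next to a noticeable difference in training length, so the two scores were not obtained under equal conditions.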

What would be your conclusion from such an experiment?

Arguably, you would find yourself in an inconclusive situation: the accuracy gap is only 0.1 percentage points, and the two sets were not even trained for the same number of epochs in every fold. Further experiments might reveal which is actually the best hyperparameter set, provided of course that you had thought to look into these details of the results in the first place. And needless to say, if all this were automated through a callback, you might have missed your best model despite having actually tried it.


The whole CV idea is implicitly based on the "all else being equal" argument (which of course is never true in practice, only approximated in the best possible way). If you feel that the number of epochs should be a hyperparameter, just include it explicitly in your CV as such, rather than inserting it through the back door of early stopping, thus possibly compromising the whole process (not to mention that early stopping itself has a hyperparameter, patience).
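With the question's KerasClassifier setup, one way to do that would be to add epochs to param_grid next to learn_rate and momentum, and to drop the EarlyStopping callback and validation_split; the epoch values here are arbitrary assumptions. The pure-Python sketch below enumerates such a grid the way scikit-learn's ParameterGrid does (sorted keys, Cartesian product), to show how many candidates and CV fits it implies.

```python
from itertools import product

# Hypothetical grid: the question's learn_rate/momentum values plus explicit epoch counts.
param_grid = dict(learn_rate=[0.01, 0.1], momentum=[0.2, 0.4], epochs=[50, 100, 150])

keys = sorted(param_grid)  # ParameterGrid also iterates over the sorted keys
candidates = [dict(zip(keys, values))
              for values in product(*(param_grid[k] for k in keys))]

n_splits = 5  # assumed number of CV folds
print(len(candidates), "candidates,", len(candidates) * n_splits, "CV fits")
print(candidates[0])
# → 12 candidates, 60 CV fits
# → {'epochs': 50, 'learn_rate': 0.01, 'momentum': 0.2}
```

Every candidate is then trained for exactly the number of epochs it names in every fold, so epochs is compared on the same footing as the other hyperparameters.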

Not intermingling these two techniques doesn't mean of course that you cannot use them sequentially: once you have obtained your best hyperparameters through CV, you can always employ early stopping when fitting the model on your whole training set (provided of course that you do have a separate validation set).


The field of deep neural nets is still (very) young, and it is true that it has yet to establish its "best practice" guidelines; add the fact that, thanks to an amazing community, there are all sorts of tools available in open source implementations, and you can easily find yourself in the (admittedly tempting) position of mixing things up just because they happen to be available. I am not necessarily saying that this is what you are attempting to do here; I am just urging more caution when combining ideas that may not have been designed to work together...


...