You can extract the important lines from the load_model and save_model functions.
For saving optimizer states, in save_model:
# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
    optimizer_weights_group = f.create_group('optimizer_weights')
    weight_values = K.batch_get_value(symbolic_weights)
For loading optimizer states, in load_model:
# Set optimizer weights.
if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
        model.model._make_train_function()
    else:
        model._make_train_function()
    # ...
    try:
        model.optimizer.set_weights(optimizer_weight_values)
Combining the lines above, here's an example:
- First fit the model for 5 epochs.
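# Imports assume standalone Keras 2.x, matching the source excerpts above.
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

X, y = np.random.rand(100, 50), np.random.randint(2, size=100)
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(X, y, epochs=5)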
Epoch 1/5
100/100 [==============================] - 0s 4ms/step - loss: 0.7716
Epoch 2/5
100/100 [==============================] - 0s 64us/step - loss: 0.7678
Epoch 3/5
100/100 [==============================] - 0s 82us/step - loss: 0.7665
Epoch 4/5
100/100 [==============================] - 0s 56us/step - loss: 0.7647
Epoch 5/5
100/100 [==============================] - 0s 76us/step - loss: 0.7638
- Now save the weights and optimizer states.
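import pickle
from keras import backend as K

model.save_weights('weights.h5')
# Fetch the optimizer's internal variables as NumPy values and pickle them.
symbolic_weights = getattr(model.optimizer, 'weights')
weight_values = K.batch_get_value(symbolic_weights)
with open('optimizer.pkl', 'wb') as f:
    pickle.dump(weight_values, f)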
- Rebuild the model in another Python session, and load the weights.
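import pickle
from keras.layers import Input, Dense
from keras.models import Model

x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.load_weights('weights.h5')
# Build the train function first so the optimizer's weight variables exist.
model._make_train_function()
with open('optimizer.pkl', 'rb') as f:
    weight_values = pickle.load(f)
model.optimizer.set_weights(weight_values)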
- Continue model training.
# X and y are the same training data as in the first step; load them again in the new session.
model.fit(X, y, epochs=5)
Epoch 1/5
100/100 [==============================] - 0s 674us/step - loss: 0.7629
Epoch 2/5
100/100 [==============================] - 0s 49us/step - loss: 0.7617
Epoch 3/5
100/100 [==============================] - 0s 49us/step - loss: 0.7611
Epoch 4/5
100/100 [==============================] - 0s 55us/step - loss: 0.7601
Epoch 5/5
100/100 [==============================] - 0s 49us/step - loss: 0.7594
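To see why restoring the optimizer state matters, and not only the model weights, here is a minimal toy sketch in plain Python. It is not Keras code: `ToyAdam` and `train` are hypothetical names made up for this illustration, and the optimizer is a simplified scalar Adam. It assumes the state is just a step counter plus two moment estimates, and compares resuming with and without that state.

```python
import math
import pickle


class ToyAdam:
    """Simplified scalar Adam (illustration only, not the Keras class)."""

    def __init__(self, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        # Optimizer state: step count and first/second moment estimates.
        self.t, self.m, self.v = 0, 0.0, 0.0

    def get_weights(self):
        return [self.t, self.m, self.v]

    def set_weights(self, weights):
        self.t, self.m, self.v = weights

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias correction
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


def train(param, optimizer, steps):
    """Minimize (param - 3)^2 for the given number of steps."""
    for _ in range(steps):
        grad = 2.0 * (param - 3.0)
        param = optimizer.step(param, grad)
    return param


# Reference: 10 uninterrupted steps.
uninterrupted = train(0.0, ToyAdam(), 10)

# Interrupted run: 5 steps, then pickle the parameter and the optimizer state.
opt = ToyAdam()
param = train(0.0, opt, 5)
saved = pickle.dumps({'param': param, 'optimizer': opt.get_weights()})

# "New session": restore both, then train 5 more steps.
checkpoint = pickle.loads(saved)
restored_opt = ToyAdam()
restored_opt.set_weights(checkpoint['optimizer'])
resumed_with_state = train(checkpoint['param'], restored_opt, 5)

# Same parameter, but a fresh optimizer whose state starts from zero.
resumed_without_state = train(checkpoint['param'], ToyAdam(), 5)

print(resumed_with_state == uninterrupted)     # True
print(resumed_without_state == uninterrupted)  # False
```

With the state restored, the resumed run reproduces the uninterrupted one exactly. With a fresh optimizer, the moment estimates and step counter restart from zero, so the updates take a different path even though the parameter itself was loaded correctly.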