I've had this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; (3) I don't want to re-implement them all myself. TensorFlow has a design philosophy of allowing a developer to gradually opt in to its lower-level APIs. As @HyeonPhilYoun notes in his comment below, the official documentation for `tf.keras.callbacks.Callback` gives an example of what we're looking for.
The following has worked for me, but it could be improved by reverse-engineering `tf.keras.Model`.
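The trick is to use `tf.keras.callbacks.CallbackList` and then manually trigger its lifecycle events from within your custom training loop. This example uses `tqdm` to give attractive progress bars, but `CallbackList` has an `add_progbar` initialization argument that lets you use the default Keras progress bar instead. `training_model` is a typical instance of `tf.keras.Model`; `batches`, `max_epochs`, and the `x`/`Y` and `test_x`/`test_Y` data are assumed to be defined elsewhere, with `batches` yielding `(x, y)` pairs.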
from tqdm.notebook import tqdm, trange
# Shared logs dict, handed to the first hooks before any metrics exist.
logs = {}

# Put the ordinary Keras callbacks you would pass to fit() here
# (e.g. EarlyStopping, ModelCheckpoint); the list starts empty.
_callbacks = []

# Wrap them in a CallbackList so their lifecycle hooks can be fired by hand
# from the custom loop below. add_history=True additionally attaches a
# History callback; model= hands training_model to the callbacks.
callbacks = tf.keras.callbacks.CallbackList(
    _callbacks,
    add_history=True,
    model=training_model,
)
callbacks.on_train_begin(logs=logs)
# Presentation: an outer progress bar over epochs. The postfix is set via
# set_postfix() only; tqdm never formats a str `postfix=` argument, so a
# "{loss:.4f}"-style template would be displayed literally.
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
)
epochs.set_postfix(loss=0, accuracy=0)

# Get a stable test set so epoch results are comparable.
# Materialise it once: it is re-iterated every epoch, and if `batches`
# returns a one-shot iterator/generator (can't tell from here) every epoch
# after the first would otherwise validate on nothing.
# NOTE(review): this holds the whole test set in memory; fine for
# moderate sizes, revisit for very large test sets.
test_batches = list(batches(test_x, test_Y))
for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)

    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play.
    training_batches = batches(x, Y)

    # Presentation: an inner, transient progress bar over this epoch's batches.
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        position=1,
        leave=False,
    )

    train_logs = {}
    # Distinct per-batch names: rebinding `x` here would clobber the full
    # training input, so the next epoch's `batches(x, Y)` would only see the
    # last batch of this one.
    for batch, (x_batch, y_batch) in enumerated_batches:
        # Presumably only relevant for stateful (e.g. RNN) layers -- confirm
        # that resetting state between batches is what your model needs.
        training_model.reset_states()
        # Only the train-specific hooks: CallbackList.on_batch_begin/end are
        # legacy aliases that dispatch to the same train-batch hooks, so
        # calling both fired every callback twice per batch.
        callbacks.on_train_batch_begin(batch, logs=train_logs)
        # Train on this batch's labels, not the full label array `Y`.
        train_logs = training_model.train_on_batch(
            x=x_batch, y=y_batch, return_dict=True)
        callbacks.on_train_batch_end(batch, logs=train_logs)
        # Presentation
        enumerated_batches.set_postfix(
            loss=float(train_logs["loss"]),
            accuracy=float(train_logs["accuracy"]))

    # Validation pass, bracketed by the test-phase hooks (as Model.fit does
    # when it evaluates validation data). No train-batch hooks fire here.
    callbacks.on_test_begin()
    val_logs = {}
    for batch, (x_batch, y_batch) in enumerate(test_batches):
        training_model.reset_states()
        callbacks.on_test_batch_begin(batch, logs=val_logs)
        # NOTE(review): under tf.keras 2.x defaults (reset_metrics=True) each
        # call reports this batch only, so val_logs ends up holding the last
        # batch's metrics rather than a whole-test-set average.
        val_logs = training_model.test_on_batch(
            x=x_batch, y=y_batch, return_dict=True)
        callbacks.on_test_batch_end(batch, logs=val_logs)
    callbacks.on_test_end(logs=val_logs)

    # Epoch logs follow Model.fit's convention: training metrics under their
    # own names, validation metrics prefixed with "val_". EarlyStopping and
    # ModelCheckpoint monitor "val_loss" by default.
    logs = dict(train_logs)
    logs.update({"val_" + key: value for key, value in val_logs.items()})

    # Presentation: the epoch bar shows the validation metrics.
    if val_logs:
        epochs.set_postfix(
            loss=float(val_logs["loss"]),
            accuracy=float(val_logs["accuracy"]))
    callbacks.on_epoch_end(epoch, logs=logs)

    # Callbacks such as EarlyStopping request a stop by setting
    # model.stop_training; honour it instead of only mentioning it.
    if getattr(training_model, "stop_training", False):
        break
# Training is over: fire the final hook with the most recent logs.
callbacks.on_train_end(logs=logs)

# Fetch the History callback that add_history=True attached -- the object we
# normally get back from keras fit().
history_object = next(
    (
        cb
        for cb in callbacks.callbacks
        if isinstance(cb, tf.keras.callbacks.History)
    ),
    None,
)
if history_object is None:
    # Explicit error instead of `assert`, which `python -O` strips out.
    raise RuntimeError("CallbackList did not provide a History callback")
与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…