Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Applying callbacks in a custom training loop in Tensorflow 2.0

I'm writing a custom training loop using the code from the TensorFlow DCGAN implementation guide, and I want to add callbacks to it. In Keras, I know callbacks are passed as an argument to the 'fit' method, but I can't find any resources on how to use them in a custom training loop. Here is the custom training loop from the TensorFlow documentation:

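import time

import tensorflow as tf
from IPython import display

# generator, discriminator, generator_loss, discriminator_loss, the two
# optimizers, BATCH_SIZE, noise_dim, seed, checkpoint, checkpoint_prefix and
# generate_and_save_images are defined earlier in the DCGAN tutorial.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".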
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Fight dragons for too long and you become a dragon yourself; gaze into the abyss for too long and the abyss will gaze back…

1 Reply

I've had this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; and (3) I don't want to re-implement them all myself. TensorFlow's design philosophy lets a developer gradually opt in to its lower-level APIs. As @HyeonPhilYoun notes in his comment below, the official documentation for tf.keras.callbacks.Callback gives an example of what we're looking for.

The following has worked for me, though it could probably be improved by reverse-engineering tf.keras.Model.

The trick is to use tf.keras.callbacks.CallbackList and then manually trigger its lifecycle events from within your custom training loop. This example uses tqdm for attractive progress bars, but CallbackList also has an add_progbar initialization argument that lets you use the default Keras progress bar instead. training_model is a typical, compiled instance of tf.keras.Model.
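The idea can be illustrated without TensorFlow. The sketch below is a hypothetical, pure-Python stand-in (MiniCallbackList and RecordingCallback are illustrative names, not Keras classes) for what a callback list does: it broadcasts each lifecycle event to every registered callback, and the training loop is responsible for firing those events in the right order. The hook names mirror those on tf.keras.callbacks.Callback; the real CallbackList does additional bookkeeping that is omitted here.

```python
class RecordingCallback:
    """Hypothetical callback that records every hook it receives."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin:{batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


class MiniCallbackList:
    """Hypothetical stand-in for a callback list: forwards on_* hooks to all callbacks."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def __getattr__(self, hook_name):
        # Only reached for attributes not found normally; broadcast on_* hooks
        if not hook_name.startswith("on_"):
            raise AttributeError(hook_name)

        def dispatch(*args, **kwargs):
            for cb in self.callbacks:
                method = getattr(cb, hook_name, None)
                if method is not None:  # skip callbacks that lack this hook
                    method(*args, **kwargs)

        return dispatch


def train(callbacks, num_epochs, num_batches):
    """A custom loop that fires the lifecycle events itself."""
    logs = {}
    callbacks.on_train_begin(logs=logs)
    for epoch in range(num_epochs):
        callbacks.on_epoch_begin(epoch, logs=logs)
        for batch in range(num_batches):
            callbacks.on_train_batch_begin(batch, logs=logs)
            logs = {"loss": 1.0 / (batch + 1)}  # placeholder for a real train step
            callbacks.on_train_batch_end(batch, logs=logs)
        callbacks.on_epoch_end(epoch, logs=logs)
    callbacks.on_train_end(logs=logs)


recorder = RecordingCallback()
train(MiniCallbackList([recorder]), num_epochs=1, num_batches=2)
print(recorder.events)
# prints: ['train_begin', 'epoch_begin:0', 'batch_begin:0', 'batch_end:0', 'batch_begin:1', 'batch_end:1', 'epoch_end:0', 'train_end']
```

In real Keras code the callbacks subclass tf.keras.callbacks.Callback, whose default hooks do nothing, so each callback only overrides the events it cares about.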

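import tensorflow as tf
from tqdm.notebook import tqdm, trange

# training_model is your compiled tf.keras.Model

# Populate with typical keras callbacks
_callbacks = []

callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)

logs = {}
# fit() clears this flag before training; do the same here
training_model.stop_training = False
callbacks.on_train_begin(logs=logs)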

# Presentation
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
    postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}")
epochs.set_postfix(loss=0, accuracy=0)

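# Get a stable test set so epoch results are comparable.
# batches() is your own helper yielding (x, y) pairs; list() materializes
# it so a generator is not exhausted after the first epoch
test_batches = list(batches(test_x, test_Y))

for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)

    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play
    training_batches = batches(x, Y)

    # Presentation
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}",
        position=1,
        leave=False)

    # Accumulate metrics over the whole epoch, as fit() does
    training_model.reset_metrics()

    # Distinct batch names so the full arrays x and Y are not overwritten
    for (batch, (x_batch, y_batch)) in enumerated_batches:
        # Only matters for stateful layers such as stateful RNNs
        training_model.reset_states()

        # On a CallbackList, on_batch_begin/end are aliases for
        # on_train_batch_begin/end, so calling both would run every
        # callback's batch hooks twice
        callbacks.on_train_batch_begin(batch, logs=logs)

        logs = training_model.train_on_batch(
            x=x_batch, y=y_batch, reset_metrics=False, return_dict=True)

        callbacks.on_train_batch_end(batch, logs=logs)

        # Presentation
        enumerated_batches.set_postfix(
            loss=float(logs["loss"]),
            accuracy=float(logs["accuracy"]))

    train_logs = logs

    training_model.reset_metrics()
    callbacks.on_test_begin(logs=logs)

    for (batch, (x_batch, y_batch)) in enumerate(test_batches):
        training_model.reset_states()

        callbacks.on_test_batch_begin(batch, logs=logs)

        logs = training_model.test_on_batch(
            x=x_batch, y=y_batch, reset_metrics=False, return_dict=True)

        callbacks.on_test_batch_end(batch, logs=logs)

    callbacks.on_test_end(logs=logs)

    # Report test metrics with a "val_" prefix, as fit() does, so callbacks
    # such as EarlyStopping(monitor="val_loss") can find them
    logs = {**train_logs, **{"val_" + k: v for k, v in logs.items()}}

    # Presentation: show the test-set metrics on the epoch bar
    epochs.set_postfix(
        loss=float(logs["val_loss"]),
        accuracy=float(logs["val_accuracy"]))

    callbacks.on_epoch_end(epoch, logs=logs)

    # EarlyStopping signals a stop by setting model.stop_training,
    # which fit() checks after every epoch
    if training_model.stop_training:
        break


callbacks.on_train_end(logs=logs)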

# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
    if isinstance(cb, tf.keras.callbacks.History):
        history_object = cb
assert history_object is not None
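For reference, fit() reports validation metrics under a "val_" prefix, and tf.keras.callbacks.EarlyStopping asks training to stop by setting model.stop_training to True, which the loop must then check. The pure-Python sketch below mimics that contract; merge_epoch_logs, FakeModel and PatienceStopper are hypothetical names, the stopping rule is a simplified version of EarlyStopping (no min_delta, baseline or weight restoring), and the loss values are made up purely to drive the example.

```python
def merge_epoch_logs(train_logs, val_logs):
    """Combine train and validation logs, prefixing validation keys with 'val_'."""
    epoch_logs = dict(train_logs)
    epoch_logs.update({"val_" + key: value for key, value in val_logs.items()})
    return epoch_logs


class FakeModel:
    """Hypothetical stand-in exposing the stop_training flag a Keras model has."""

    def __init__(self):
        self.stop_training = False


class PatienceStopper:
    """Simplified early stopping: stop after `patience` epochs without improvement."""

    def __init__(self, model, monitor="val_loss", patience=2):
        self.model = model
        self.monitor = monitor
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # monitored key missing, e.g. no "val_" prefix was added
        if current < self.best:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True


# Made-up per-epoch losses, used only to drive the sketch
train_losses = [0.9, 0.7, 0.6, 0.5, 0.4, 0.3]
val_losses = [0.8, 0.6, 0.65, 0.7, 0.5, 0.4]

model = FakeModel()
stopper = PatienceStopper(model, monitor="val_loss", patience=2)
epochs_run = 0
for epoch, (tl, vl) in enumerate(zip(train_losses, val_losses)):
    logs = merge_epoch_logs({"loss": tl}, {"loss": vl})
    stopper.on_epoch_end(epoch, logs=logs)
    epochs_run += 1
    if model.stop_training:  # the check a custom loop makes after on_epoch_end
        break

print(epochs_run, stopper.best)  # prints: 4 0.6
```

If the "val_" keys were never added, the stopper would silently do nothing, which is an easy failure to miss in a hand-written loop.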

...