I've been trying to display images, their segmentations and the predicted segmentations in TensorBoard during training, without success. I'm using TensorFlow 2+.
```python
import numpy as np
import tensorflow as tf

# brightest_imgseg_pair_2D, image_grid and plot_to_image are helper functions defined elsewhere (not shown).


class ImageHistory(tf.keras.callbacks.Callback):
    def __init__(self, tensorboard_dir, data, draw_interval=100, num_images_to_show=3):
        super(ImageHistory, self).__init__()
        self.data = data
        self.draw_interval = draw_interval
        self.tensorboard_dir = tensorboard_dir
        self.num_image_to_show = num_images_to_show

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.draw_interval == 0:
            recap_images = []
            for batch_imgseg in self.data.take(self.num_image_to_show):
                batch_pred = self.model.predict(batch_imgseg)
                # Get `best` 2D slices from batch and images
                img2d, seg2d, pred2d = brightest_imgseg_pair_2D(batch_imgseg[0], batch_imgseg[1], batch_pred)
                # Display them in a grid
                figure = image_grid(img2d, seg2d, pred2d)
                figure.savefig(f'logs/images/fig_{batch}.png')
                # Transform the figure into a tensor
                recap_image = plot_to_image(figure)
                recap_images.append(recap_image)
            # Stack into one (N, H, W, 4) batch; assumes every figure renders to 288x432 RGBA pixels
            recap_images = np.reshape(recap_images, (-1, 288, 432, 4))
            writer = tf.summary.create_file_writer(str(self.tensorboard_dir))
            with writer.as_default():
                tf.summary.image("Images and segmentations", recap_images,
                                 max_outputs=len(recap_images), step=batch)
```
This class is called like this (where `train_data` is a `tf.data.Dataset`):
```python
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logDir)
image_history_callback = ImageHistory(tensorboard_dir=logDir / 'images', data=train_data,
                                      draw_interval=10, num_images_to_show=2)
model_history = model.fit(train_data,
                          callbacks=[tb_callback, image_history_callback])
```
I'm using some TensorFlow boilerplate code in the callback above (`plot_to_image`).
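For reference, a `plot_to_image` helper along the lines of the one in TensorFlow's image-summary guide is sketched below; the exact copy used in the callback is not shown, so treat this as an assumption. It renders the figure to PNG in memory and decodes it to a `(1, H, W, 4)` uint8 tensor, which is why each `recap_image` carries a leading batch dimension and four (RGBA) channels before the reshape.

```python
import io

import matplotlib.pyplot as plt
import tensorflow as tf


def plot_to_image(figure):
    """Render a matplotlib figure to a (1, H, W, 4) uint8 tensor and close the figure."""
    buf = io.BytesIO()
    # Save the figure that was passed in (not pyplot's "current" figure) as PNG bytes.
    figure.savefig(buf, format="png")
    # Close it so it is not displayed in a notebook and its memory is released.
    plt.close(figure)
    buf.seek(0)
    # Decode the PNG into an (H, W, 4) RGBA uint8 tensor.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension that tf.summary.image expects.
    return tf.expand_dims(image, 0)
```

Calling `figure.savefig` rather than `plt.savefig` is a deliberate choice in this sketch: it avoids silently rendering a different figure when several are open.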
I added the line `figure.savefig(f'logs/images/fig_{batch}.png')` to start troubleshooting: the images I want to display are being generated correctly.
Also, the same code works if I don't use it during training, i.e. if I load my dataset (the same way I do before calling `model.fit(...)`), take batches out of it and run the body of the `for batch_imgseg` loop myself.
I'm wondering whether the `file_writer` has to be used differently in a `Callback` than in a notebook.
EDIT: I printed out the result of `tf.summary.image()`, and it returns `False`. According to the TF docs, the return value is "True on success, or false if no summary was emitted because no default summary writer was available."

So, as suspected, it is an issue with the `file_writer`. Continuing to debug...
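One way to narrow this down is a stripped-down callback that creates the writer once, for the whole run, and records every return value of `tf.summary.image`. This is a debugging sketch under stated assumptions, not a confirmed fix: `ImageHistorySketch`, `global_step` and `results` are hypothetical names, and zero-filled placeholder images stand in for the `image_grid`/`plot_to_image` output. It also uses a running step counter, because the `batch` argument of `on_train_batch_end` restarts at 0 every epoch, so `step=batch` reuses the same step values in each epoch.

```python
import numpy as np
import tensorflow as tf


class ImageHistorySketch(tf.keras.callbacks.Callback):
    """Hypothetical debugging variant of ImageHistory (not the original callback)."""

    def __init__(self, log_dir, draw_interval=10):
        super().__init__()
        self.log_dir = str(log_dir)
        self.draw_interval = draw_interval
        self.global_step = 0   # batches seen across all epochs
        self.writer = None
        self.results = []      # return value of every tf.summary.image call

    def on_train_begin(self, logs=None):
        # Create the file writer once per training run instead of on every draw.
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def on_train_batch_end(self, batch, logs=None):
        # `batch` restarts at 0 each epoch, so the running counter is used as the step.
        if self.global_step % self.draw_interval == 0:
            # Placeholder images; the real callback would render figures here.
            images = np.zeros((2, 8, 8, 4), dtype=np.uint8)
            with self.writer.as_default():
                written = tf.summary.image("Images and segmentations", images,
                                           max_outputs=len(images), step=self.global_step)
            self.results.append(bool(written))
        self.global_step += 1

    def on_train_end(self, logs=None):
        # Make sure everything buffered reaches the event file.
        self.writer.flush()
```

Passing `ImageHistorySketch(logDir / 'images')` to `model.fit` alongside `tb_callback` and inspecting `results` afterwards would show whether any call still returns `False` once the rendering code is out of the picture.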
question from:
https://stackoverflow.com/questions/65853340/custom-tf-keras-callback-to-display-image-and-predicted-segmentation-not-showing