Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share



python - Keras Visualization of Model Built from Functional API

Is there an easy way to visualize a Keras model built with the Functional API?

Right now, the best way for me to debug a Sequential model at a high level is:

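from IPython.display import SVG
from keras.models import Sequential
from keras.utils.vis_utils import model_to_dot

model = Sequential()
model.add(...)
...

model.summary()  # summary() prints by itself and returns None, so print() is not needed
SVG(model_to_dot(model).create(prog='dot', format='svg'))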

However, I am having a hard time finding a good way to visualize a model built with the Functional API, i.e. a more complex, non-sequential model.



1 Reply


Yes, there is: keras.utils has a plot_model() function, explained in detail in the Keras documentation on model visualization. You already seem familiar with keras.utils.vis_utils and the model_to_dot method, but this is another option. Its usage is something like:

from keras.utils import plot_model
# needs the pydot package and the Graphviz binaries to be installed
plot_model(model, to_file='model.png')
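For a non-sequential model the call is exactly the same. Below is a minimal sketch, assuming TensorFlow 2.x (`tensorflow.keras`); the two-branch model, the helper `build_two_branch_model`, its layer names and the file name `functional_model.png` are all hypothetical and only there to give plot_model something non-sequential to draw. `show_shapes=True` annotates each node with its tensor shapes. If pydot or Graphviz is missing, plot_model raises an `ImportError` (outside notebooks), which the sketch catches instead of crashing.

```python
# Sketch: assumes TensorFlow 2.x; pydot + Graphviz are needed for the image itself.
from tensorflow import keras


def build_two_branch_model() -> keras.Model:
    """Hypothetical non-sequential model: two inputs merged into one output."""
    image_features = keras.Input(shape=(32,), name="image_features")
    text_features = keras.Input(shape=(16,), name="text_features")

    # Each input gets its own Dense branch.
    image_branch = keras.layers.Dense(8, activation="relu", name="image_dense")(image_features)
    text_branch = keras.layers.Dense(8, activation="relu", name="text_dense")(text_features)

    # The branches join here, which a Sequential model cannot express.
    merged = keras.layers.Concatenate(name="merge")([image_branch, text_branch])
    output = keras.layers.Dense(1, name="score")(merged)
    return keras.Model(inputs=[image_features, text_features], outputs=output)


model = build_two_branch_model()
# For functional models, summary() adds a "Connected to" column listing each layer's inputs.
model.summary()

try:
    keras.utils.plot_model(
        model,
        to_file="functional_model.png",  # hypothetical output path
        show_shapes=True,                # annotate nodes with tensor shapes
        show_layer_names=True,           # label nodes with the names given above
    )
except ImportError as err:
    print(f"plot_model unavailable (install pydot and Graphviz): {err}")
```

Giving layers explicit `name=` values is optional, but it makes both the summary and the plotted graph much easier to read once the model has several branches.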

To be honest, that is the best I have managed to find using Keras alone. Using model.summary() as you did is also useful sometimes. I also wish there were a tool for better visualization of one's models, perhaps even one that shows the weights per layer to help decide on optimal network structures and initializations (if you know of one, please tell me :] ).


Probably the best option you currently have is to visualize things in TensorBoard, which you can use from Keras through the TensorBoard callback. It lets you visualize your training and the metrics of interest, as well as some information on your layers' activations, biases, kernels, etc. Basically, you have to add this code to your program before fitting your model:

from keras.callbacks import TensorBoard
# indicate the folder to save the logs in, plus other options
tensorboard = TensorBoard(log_dir='./logs/run1', histogram_freq=1,
    write_graph=True, write_images=False)

# put it in your callback list, where you can include other callbacks
callbacks_list = [tensorboard]
# then pass the list to fit(); also pass validation_data
# (older Keras versions require it when histogram_freq > 0)
model.fit(X, Y, callbacks=callbacks_list, epochs=64,
    validation_data=(X_test, Y_test), shuffle=True)
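If you want something that runs as-is, here is a minimal end-to-end sketch, assuming TensorFlow 2.x (`tensorflow.keras`). The tiny functional model, `LOG_DIR` and the random arrays are placeholders for your own model and X, Y, X_test, Y_test. With `write_graph=True`, TensorBoard's Graphs tab draws the model's layer graph, which is another way of looking at a non-sequential model.

```python
# Sketch: assumes TensorFlow 2.x; the model and the random data are placeholders.
import os

import numpy as np
from tensorflow import keras

LOG_DIR = os.path.join("logs", "run1")  # hypothetical log folder for this run

# A small functional model standing in for your own network.
inputs = keras.Input(shape=(4,))
hidden = keras.layers.Dense(8, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(hidden)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Random placeholder data in place of X, Y, X_test, Y_test.
rng = np.random.default_rng(0)
X, Y = rng.normal(size=(64, 4)), rng.normal(size=(64, 1))
X_test, Y_test = rng.normal(size=(16, 4)), rng.normal(size=(16, 1))

tensorboard = keras.callbacks.TensorBoard(
    log_dir=LOG_DIR,
    histogram_freq=1,  # write weight histograms every epoch
    write_graph=True,  # populates the Graphs tab with the model graph
)
history = model.fit(
    X, Y,
    epochs=2,
    validation_data=(X_test, Y_test),
    callbacks=[tensorboard],
    verbose=0,
)
```

Two epochs on random data is only enough to produce logs to look at; the curves themselves are meaningless.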

You can then run TensorBoard (which runs locally as a web server) with the following command in your terminal:

tensorboard --logdir=./logs/run1

It will then print the address and port (6006 by default) where you can view your training. If you have several runs, you can pass --logdir=./logs instead to view them together for comparison. There are of course more options for using TensorBoard, so I suggest checking its documentation if you are considering using it.
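Comparing runs only works if each run writes to its own sub-folder of the common log directory. A minimal sketch of one way to do that follows; `make_run_logdir` is a hypothetical helper, not part of Keras or TensorBoard, and it simply names each folder after the current time.

```python
# Sketch: hypothetical helper giving each training run its own sub-folder of ./logs,
# so that `tensorboard --logdir=./logs` lists the runs side by side.
import os
from datetime import datetime
from typing import Optional


def make_run_logdir(root: str = "logs", name: Optional[str] = None) -> str:
    """Create and return a run directory such as logs/run-20240101-120000."""
    if name is None:
        # A timestamp separates runs as long as they start at least a second apart.
        name = datetime.now().strftime("run-%Y%m%d-%H%M%S")
    path = os.path.join(root, name)
    os.makedirs(path, exist_ok=True)  # also creates `root` if it does not exist yet
    return path
```

You would then pass the result to the callback, e.g. `TensorBoard(log_dir=make_run_logdir())`, and point TensorBoard at the parent `logs` folder.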

