You can visualize the graph of any tf.function-decorated function, but first you have to trace its execution.
Visualizing the graph of a Keras model means visualizing its call method.
By default, this method is not decorated with tf.function, so you have to wrap the model call in a function that is decorated with tf.function and then execute that function.
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


# Wrap the model call so that it runs (and can be traced) as a graph
@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)

# Start recording the graph (and profiler data) before the call to trace
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass; the input shape matches input_shape=(28, 28) plus a batch dimension
traceme(tf.zeros((1, 28, 28)))
# Write the recorded trace to the log directory
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
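
To check that the wrapper really produced a graph without opening TensorBoard, you can also ask the decorated function for its concrete function and list the operations it contains. The following is only a minimal sketch under some assumptions: the model is a smaller stand-in for the one above, it uses tf.keras.Input in place of the input_shape argument, and the names concrete and op_types are made up for illustration.

```python
import tensorflow as tf

# Stand-in model (an assumption for this sketch, not the exact model above)
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


# Tracing for a fixed input signature returns a ConcreteFunction;
# its .graph attribute is the tf.Graph built from the model call.
concrete = traceme.get_concrete_function(tf.TensorSpec((1, 28, 28), tf.float32))

# Collect the type of every operation in the traced graph
op_types = [op.type for op in concrete.graph.get_operations()]
print(len(op_types), "operations in the traced graph")
print(sorted(set(op_types)))
```

To look at the exported trace itself, start TensorBoard with `tensorboard --logdir log` and open the Graphs tab.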