Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - How to replace the input of a saved graph, e.g. a placeholder, with a Dataset iterator?

I have a saved TensorFlow graph that consumes input through a placeholder, fed with the feed_dict argument:

sess.run(my_tensor, feed_dict={input_image: image})

Because feeding data with a Dataset iterator is more efficient, I want to load the saved graph, replace the input_image placeholder with an iterator, and run it. How can I do that? Is there a better way? An answer with a code example would be highly appreciated.


1 Reply


You can achieve that by serializing your graph and re-importing it with tf.import_graph_def, which has an input_map argument for plugging in inputs at the desired places.

To do that, you need to know at least the names of the inputs you are replacing and of the outputs you want to evaluate (x and y, respectively, in my examples).
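If you are not sure what the input is called, one way to find candidates is to list the Placeholder ops in the GraphDef. The key to use in input_map is then a tensor name, i.e. the op name followed by an output index, such as x:0. A minimal sketch, assuming TensorFlow 2.x (or a late 1.x release that ships tf.compat.v1) and an explicit tf.Graph; list_placeholders is a hypothetical helper, and x/y are just the example names:

```python
import tensorflow as tf

def list_placeholders(graph_def):
    """Hypothetical helper: names of all Placeholder ops in a GraphDef (candidate inputs)."""
    return [node.name for node in graph_def.node if node.op == "Placeholder"]

# Toy graph standing in for the saved one; 'x' and 'y' are example names.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.int64, shape=(), name="x")
    y = tf.square(x, name="y")

graph_def = graph.as_graph_def()
print(list_placeholders(graph_def))  # ['x']
# Tensor names have the form "<op name>:<output index>", which is what input_map expects.
print(x.name, y.name)  # x:0 y:0
```

For a real model, the output name is usually easier to get from the code that built the graph, since any op can be an output.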

import tensorflow as tf

# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')

# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
  print("with placeholder:")
  for i in range(10):
    print(sess.run(y, {x: i}))

# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()

tf.reset_default_graph()

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])

# enjoy Dataset inputs!
with tf.Session() as sess:
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        
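The snippet above uses the TensorFlow 1.x API (tf.placeholder, tf.Session, Dataset.make_one_shot_iterator), which is not available at the top level of TensorFlow 2.x. A sketch of the same input_map rewiring, assuming TF 2.x with the tf.compat.v1 module, and using two explicit tf.Graph objects instead of resetting the default graph:

```python
import tensorflow as tf

# Original graph with a placeholder input (stand-in for the restored graph).
original = tf.Graph()
with original.as_default():
    x = tf.compat.v1.placeholder(tf.int64, shape=(), name="x")
    y = tf.square(x, name="y")
graph_def = original.as_graph_def()

# New graph: build the Dataset pipeline, then import graph_def with 'x:0' remapped to it.
rewired = tf.Graph()
with rewired.as_default():
    dataset = tf.data.Dataset.range(10)
    batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    [y_new] = tf.compat.v1.import_graph_def(
        graph_def, input_map={"x:0": batch}, return_elements=["y:0"])

results = []
with tf.compat.v1.Session(graph=rewired) as sess:
    try:
        while True:
            results.append(int(sess.run(y_new)))
    except tf.errors.OutOfRangeError:
        pass  # the Dataset is exhausted

print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Because import_graph_def uses the default name "import", the imported output op is called import/y in the new graph; pass name="" if you prefer to keep the original names.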

Note that the placeholder node is still in the graph, because I did not bother to parse graph_def and remove it. You could remove it as an improvement, although I think it is also fine to leave it there.
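If you do want to drop the leftover placeholder, one option (a sketch, not the only way) is to prune the rewired graph down to the nodes the output actually depends on, using tf.compat.v1.graph_util.extract_sub_graph. This assumes TF 2.x with tf.compat.v1 (or a late 1.x release) and the example names x and y:

```python
import tensorflow as tf

original = tf.Graph()
with original.as_default():
    x = tf.compat.v1.placeholder(tf.int64, shape=(), name="x")
    y = tf.square(x, name="y")
graph_def = original.as_graph_def()

rewired = tf.Graph()
with rewired.as_default():
    batch = tf.compat.v1.data.make_one_shot_iterator(
        tf.data.Dataset.range(10)).get_next()
    tf.compat.v1.import_graph_def(graph_def, input_map={"x:0": batch})

# The placeholder is still imported (as 'import/x'), but no op reads from it anymore.
all_names = [node.name for node in rewired.as_graph_def().node]

# Keep only the nodes reachable backwards from 'import/y'; 'import/x' is not among them.
pruned = tf.compat.v1.graph_util.extract_sub_graph(
    rewired.as_graph_def(), ["import/y"])
names = [node.name for node in pruned.node]
print("import/x" in names, "import/y" in names)  # False True
```

The result is a plain GraphDef, so you would import it into a fresh graph (or write it to disk) before running it.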

Depending on how you restore your graph, the input replacement may already be built into the loader, which makes things simpler (no need to go back to a GraphDef). For example, if you load your graph from a .meta file, you can use tf.train.import_meta_graph, which accepts the same input_map argument.

import tensorflow as tf

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where you plug in your input;
# graph_filepath is the path to your saved .meta file
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')

# enjoy Dataset inputs!
with tf.Session() as sess:
  # not needed here, but in practice you would also need to restore weights
  # restorer.restore(sess, weights_filepath)
  print('with Dataset:')
  try:
    while True:
      print(sess.run(y))
  except tf.errors.OutOfRangeError:
    pass        
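The snippet above assumes you already have a .meta file at graph_filepath. For a fully self-contained version, here is a sketch that first exports a toy graph with tf.compat.v1.train.export_meta_graph to a temporary file, then re-imports it with input_map. It assumes TF 2.x with tf.compat.v1; the file name model.meta is arbitrary, and export/import are done inside explicit graph contexts because meta-graph export and import are not supported in eager mode:

```python
import os
import tempfile

import tensorflow as tf

# Export a toy graph to a .meta file (stand-in for the file you already have).
original = tf.Graph()
with original.as_default():
    x = tf.compat.v1.placeholder(tf.int64, shape=(), name="x")
    y = tf.square(x, name="y")
    graph_filepath = os.path.join(tempfile.mkdtemp(), "model.meta")
    tf.compat.v1.train.export_meta_graph(filename=graph_filepath, graph=original)

# Build the Dataset pipeline first, then import the .meta file with 'x:0' remapped.
rewired = tf.Graph()
with rewired.as_default():
    batch = tf.compat.v1.data.make_one_shot_iterator(
        tf.data.Dataset.range(10)).get_next()
    restorer = tf.compat.v1.train.import_meta_graph(
        graph_filepath, input_map={"x:0": batch})
    # import_meta_graph adds no name prefix by default, so 'y:0' keeps its name.
    y_new = rewired.get_tensor_by_name("y:0")

results = []
with tf.compat.v1.Session(graph=rewired) as sess:
    # The toy graph has no variables, so there is nothing to restore here;
    # with a real model you would call restorer.restore(sess, weights_filepath).
    try:
        while True:
            results.append(int(sess.run(y_new)))
    except tf.errors.OutOfRangeError:
        pass  # the Dataset is exhausted

print(results)
```

Building the iterator before calling import_meta_graph matters: input_map needs the replacement tensor to exist in the graph you import into.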
