
python - tensorflow.train.import_meta_graph does not work?

I am trying to simply save and restore a graph, but even the simplest example does not work as expected (tested with TensorFlow 0.9.0 or 0.10.0 on 64-bit Linux without CUDA, using Python 2.7 or 3.5.2).

First I save the graph like this:

import tensorflow as tf

# Build a small graph: v4 = (v1 * v2) + 22.0
v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1, v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3, c1)

# Run it once, then export the graph structure as a MetaGraphDef.
sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
g1 = tf.train.export_meta_graph("file")
## alternatively I also tried:
## g1 = tf.train.export_meta_graph("file", collection_list=["v4"])

This creates a file "file" that is non-empty and also sets g1 to something that looks like a proper graph definition.
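
For example, printing the node definitions inside g1 with

print([node.name for node in g1.graph_def.node])

lists the operations created above.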

Then I try to restore this graph:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

This works without an error, but does not return anything at all.
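
The graph itself does seem to be imported into the default graph, though; for example, listing its operations after the import shows the nodes from the first program:

print([op.name for op in tf.get_default_graph().get_operations()])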

Can anyone provide the code needed to simply save the graph for "v4" and completely restore it, so that running it in a new session produces the same result?



1 Reply


To reuse a MetaGraphDef, you will need to record the names of interesting tensors in your original graph. For example, in the first program, set an explicit name argument in the definition of v1, v2 and v4:

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")

Then, you can use the string names of the tensors in the original graph in your call to sess.run(). For example, the following snippet should work:

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")

sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})

Alternatively, you can use tf.get_default_graph().get_tensor_by_name() to get tf.Tensor objects for the tensors of interest, which you can then pass to sess.run():

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()

v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
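
As an aside, the trailing ":0" in these names refers to the first output of an operation: "v4" is the name of the tf.Operation, while "v4:0" is the name of the tf.Tensor it produces, and the tensor is what you feed and fetch values through. Continuing from the snippet above, this can be checked as follows (purely for illustration):

# "v4" names the operation; "v4:0" names its first (and only) output tensor.
v4_op = g.get_operation_by_name("v4")
assert v4_op.outputs[0] is v4
assert v4.name == "v4:0"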

UPDATE: Based on the discussion in the comments, here is a complete example of saving and loading, including saving the variable contents. To illustrate that the variable's value really is saved, vx is doubled in a separate operation before the model is written out.

Saving:

import tensorflow as tf

# Build the graph, this time with an explicitly named variable vx.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])

sess = tf.Session()
sess.run(tf.initialize_all_variables())
# Double vx so that the restored value (20.0) differs from the initial value (10.0).
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print(result)
# Writes the checkpoint "./model_ex1" and, by default, the graph structure to "./model_ex1.meta".
saver.save(sess, "./model_ex1")

Restoring:

import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)
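
Both scripts should print (approximately) 59.6: v3 = 12.0 * 3.3 = 39.6, and vx was doubled from 10.0 to 20.0 before the checkpoint was written, so v4 = 39.6 + 20.0 = 59.6. Seeing 59.6 rather than 49.6 (which would correspond to the initial value 10.0 of vx) confirms that the variable's contents were indeed saved and restored.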

The bottom line is that, in order to make use of a saved model, you must remember the names of at least some of the nodes (e.g. a training op, an input placeholder, an evaluation tensor, etc.). The MetaGraphDef stores the list of variables that are contained in the model, and helps to restore these from a checkpoint, but you are required to reconstruct the tensors/operations used in training/evaluating the model yourself.
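
As an aside, if you would rather not hard-code tensor names in the restoring program, one common pattern (not shown in the code above, just a sketch; the file name "file_with_collection" and the collection key "my_endpoints" are arbitrary) is to add the tensors of interest to a named collection before exporting. Collections are serialized as part of the MetaGraphDef, so they can be looked up again after importing:

# Saving (run as its own program): record the tensors of interest in a collection.
import tensorflow as tf
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v4 = tf.add(tf.mul(v1, v2), tf.constant(22.0), name="v4")
for tensor in [v1, v2, v4]:
    tf.add_to_collection("my_endpoints", tensor)
tf.train.export_meta_graph("./file_with_collection")

# Restoring (run as a separate program): look the tensors up via the collection.
import tensorflow as tf
tf.train.import_meta_graph("./file_with_collection")
v1, v2, v4 = tf.get_collection("my_endpoints")
sess = tf.Session()
print(sess.run(v4, feed_dict={v1: 12.0, v2: 3.3}))

This keeps the restoring program decoupled from the exact node names chosen when the graph was built.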

