Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Loading two models from Saver in the same TensorFlow session

I have two networks: a Model which generates output and an Adversary which grades the output.

Both have been trained separately, but now I need to combine their outputs in a single session.

I've attempted to implement the solution proposed in this post: Run multiple pre-trained Tensorflow nets at the same time

My code

import tensorflow as tf

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip the ":0" output-index suffix, which is not part of the checkpoint key
    model_varlist = {v.name.lstrip("model/")[:-2]: v
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("adv/")[:-2]: v
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)
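A detail in the code above that may be worth checking: str.lstrip("model/") removes any leading characters drawn from the set {m, o, d, e, l, /}, not the literal prefix "model/", so a name such as model/embedding:0 comes out as bedding:0 and no longer matches its checkpoint key. Separately, if Model builds its weights with tf.get_variable (an assumption; the question does not show Model), tf.name_scope does not prefix those variables at all (tf.variable_scope does), so the "model" and "adv" filters may not select what they appear to. The plain-Python sketch below uses hypothetical variable names and helper names to show a prefix-safe mapping; on Python 3.9+ str.removeprefix does the same job as strip_prefix.

```python
def strip_prefix(name: str, prefix: str) -> str:
    """Remove `prefix` from the start of `name` only if it is really there."""
    return name[len(prefix):] if name.startswith(prefix) else name


def checkpoint_key(var_name: str, scope: str) -> str:
    """Map a graph variable name like 'model/rnn/W:0' to 'rnn/W'.

    Drops the ':N' output-index suffix, then the 'scope/' prefix.
    """
    op_name = var_name.rsplit(":", 1)[0]  # 'model/rnn/W:0' -> 'model/rnn/W'
    return strip_prefix(op_name, scope + "/")


# str.lstrip treats its argument as a SET of characters, not as a prefix:
print("model/embedding:0".lstrip("model/"))          # bedding:0
print(checkpoint_key("model/embedding:0", "model"))  # embedding

# In the question's TF 1.x code this would be used roughly as (not run here):
#   model_varlist = {checkpoint_key(v.name, "model"): v
#                    for v in tf.global_variables() if v.name.startswith("model/")}
```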

The problem

The call to model_saver.restore() appears to do nothing. In another module I create a saver with tf.train.Saver(tf.global_variables()), and it restores the checkpoint fine.

The model has model.tvars = tf.trainable_variables(). To check what was happening, I used sess.run() to fetch the tvars before and after the restore. Each time, the variables still hold their random initial values; the values from the checkpoint are never assigned.
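The before/after check described above can be written as a small helper. This is a minimal sketch: changed_variables is a hypothetical name, and the dummy arrays stand in for what sess.run(model.tvars) would return. A restore that silently did nothing shows up as every entry being False.

```python
import numpy as np


def changed_variables(before: dict, after: dict) -> dict:
    """Return {name: bool} saying whether each variable's value changed.

    `before` and `after` map variable names to the numpy arrays that
    sess.run() returned before and after the restore.
    """
    return {name: not np.array_equal(before[name], after[name]) for name in before}


# In a TF 1.x session the two dicts could be built like this (not run here):
#   names = [v.name for v in model.tvars]
#   before = dict(zip(names, sess.run(model.tvars)))
#   model_saver.restore(sess, model_ckpt.model_checkpoint_path)
#   after = dict(zip(names, sess.run(model.tvars)))
before = {"w": np.zeros(3), "b": np.ones(2)}
after = {"w": np.array([0.5, 0.1, 0.2]), "b": np.ones(2)}
print(changed_variables(before, after))  # {'w': True, 'b': False}
```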

Any thoughts on why model_saver.restore() appears to be doing nothing?


He who fights with dragons for too long becomes a dragon himself; gaze too long into the abyss, and the abyss will gaze back…

1 Reply

Solving this problem took a long time, so I'm posting my (likely imperfect) solution in case anyone else needs it.

To diagnose the problem, I manually looped through the variables and assigned them one by one. I then noticed that after a variable was assigned, its name changed. This is described here: TensorFlow checkpoint save and read
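A lighter-weight way to look for this kind of name mismatch, without assigning variables one by one, is to diff the keys stored in the checkpoint against the names the saver expects. The sketch below does the diff in plain Python with made-up names (name_mismatches is a hypothetical helper); the commented lines show how the two lists might be collected with the TF 1.x API, where tf.train.list_variables (available in later TF 1.x releases) returns (key, shape) pairs for a checkpoint and v.op.name is a variable's name without the ":0" suffix.

```python
def name_mismatches(checkpoint_keys, graph_var_names):
    """Return the names that appear on only one side of a checkpoint/graph pair."""
    ckpt, graph = set(checkpoint_keys), set(graph_var_names)
    return {
        "only_in_checkpoint": sorted(ckpt - graph),
        "only_in_graph": sorted(graph - ckpt),
    }


# Possible TF 1.x collection code (not run here). When the Saver is given a
# dict var_list, use the dict's keys instead of v.op.name for the graph side.
#   checkpoint_keys = [name for name, _shape in tf.train.list_variables(ckpt_path)]
#   graph_var_names = [v.op.name for v in tf.global_variables()]
report = name_mismatches(["rnn/W", "rnn/b"], ["model/rnn/W", "rnn/b"])
print(report)
# {'only_in_checkpoint': ['rnn/W'], 'only_in_graph': ['model/rnn/W']}
```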

Based on the advice in the post linked above, I built each model in its own graph. Because a session is bound to a single graph, each graph also needed its own session, which meant handling session management differently.

First, I created two graphs, one per model:

import tensorflow as tf

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)
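Part of why separate graphs help is that each tf.Graph keeps its own op namespace. In a shared graph, asking for a name that is already taken gets it silently uniquified (w becomes w_1), and the new name no longer matches the key in the checkpoint. A small sketch, assuming TensorFlow is installed and using the tf.compat.v1 API so it also runs under TF 2.x; the variable name w and both helper names are arbitrary:

```python
import tensorflow as tf

v1 = tf.compat.v1  # TF 1.x-style API, also available in TF 2.x


def names_in_shared_graph():
    """Create two variables that both ask for the name 'w' in ONE graph."""
    g = tf.Graph()
    with g.as_default():
        a = v1.Variable(0.0, name="w")
        b = v1.Variable(0.0, name="w")  # name already taken -> uniquified
    return a.name, b.name


def names_in_separate_graphs():
    """Create one variable called 'w' in each of two separate graphs."""
    names = []
    for _ in range(2):
        g = tf.Graph()
        with g.as_default():
            names.append(v1.Variable(0.0, name="w").name)
    return tuple(names)


print(names_in_shared_graph())     # ('w:0', 'w_1:0')
print(names_in_separate_graphs())  # ('w:0', 'w:0')
```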

Then one session per graph:

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

Then I initialised the variables in each session and restored each graph from its own checkpoint:

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
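The two restore blocks above repeat the same steps, so they could be folded into one helper. A sketch, assuming TensorFlow with the tf.compat.v1 API; restore_latest and build_graph are hypothetical names, and the script first writes a throwaway checkpoint so it has something to restore:

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1  # TF 1.x-style graph/session API (also present in TF 2.x)


def restore_latest(graph, sess, save_dir):
    """Initialise all variables in `graph`, then load the newest checkpoint
    found in `save_dir` into `sess`."""
    with graph.as_default():
        sess.run(v1.global_variables_initializer())
        saver = v1.train.Saver(v1.global_variables())
        ckpt = v1.train.get_checkpoint_state(save_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError("no checkpoint found in " + save_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
    return saver


def build_graph():
    """Hypothetical stand-in for Model(args) / Adversary(adv_args)."""
    g = tf.Graph()
    with g.as_default():
        w = v1.Variable(0.0, name="w")
    return g, w


# 1) Write a throwaway checkpoint in which w == 3.0.
save_dir = tempfile.mkdtemp()
g, w = build_graph()
with g.as_default():
    init = v1.global_variables_initializer()
    set_w = w.assign(3.0)
    writer = v1.train.Saver()
with v1.Session(graph=g) as s:
    s.run(init)
    s.run(set_w)
    writer.save(s, os.path.join(save_dir, "model.ckpt"))

# 2) Rebuild the graph from scratch and restore into a fresh session.
model_graph, model_w = build_graph()
sess = v1.Session(graph=model_graph)
restore_latest(model_graph, sess, save_dir)
restored = sess.run(model_w)
print(restored)  # 3.0
sess.close()
```

The same helper would then be called once per graph/session pair, for example restore_latest(adv_graph, adv_sess, adv_args.save_dir).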

From then on, whenever I needed one of the sessions, I wrapped the TensorFlow calls for it in a with sess.as_default(): block (or with adv_sess.as_default(): for the adversary). At the end I close both sessions manually:

sess.close()
adv_sess.close()
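Closing the sessions by hand works, but they stay open if an exception is raised before the close() calls. One alternative, sketched here with trivial stand-in graphs and assuming the tf.compat.v1 API, is contextlib.ExitStack combined with contextlib.closing, which closes both sessions when the block exits. closing() is used instead of entering each Session as a context manager because Session.__enter__ also makes that session (and its graph) the default, and nesting two of those gets confusing.

```python
import contextlib

import tensorflow as tf

v1 = tf.compat.v1

# Trivial stand-ins for the real Model and Adversary graphs (hypothetical).
model_graph, adv_graph = tf.Graph(), tf.Graph()
with model_graph.as_default():
    x = v1.constant(2.0) * 3.0
with adv_graph.as_default():
    y = v1.constant(1.0) + 1.0

with contextlib.ExitStack() as stack:
    # closing() only calls .close() on exit; it does not make either session
    # the default, so each one is used explicitly via sess.run / adv_sess.run.
    sess = stack.enter_context(contextlib.closing(v1.Session(graph=model_graph)))
    adv_sess = stack.enter_context(contextlib.closing(v1.Session(graph=adv_graph)))
    print(sess.run(x), adv_sess.run(y))  # 6.0 2.0
# Both sessions are closed here, in reverse order, even if the block raised.
```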
