Solving this problem took a long time, so I'm posting my likely imperfect solution in case anyone else needs it.
To diagnose the problem, I manually looped through the variables and assigned them one by one. I then noticed that a variable's name would change after it was assigned. This is described here: TensorFlow checkpoint save and read
Based on the advice in that post, I ran each model in its own graph. That also meant running each graph in its own session, so I had to handle session management differently.
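The renaming is easier to see in isolation. The following is a minimal sketch, not the original models: it assumes the TF 1.x graph API is available through tensorflow.compat.v1, and the variable name w is purely illustrative. Two variables requested under the same name in one graph get uniquified names, while a variable in a separate graph keeps the name it was given, which is presumably why names stop matching a checkpoint once two models share a graph.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

# Two variables that ask for the same name inside ONE graph.
shared_graph = tf.Graph()
with shared_graph.as_default():
    first = tf.Variable(0.0, name="w")
    second = tf.Variable(0.0, name="w")  # TensorFlow uniquifies this name

# The same requested name inside a SEPARATE graph.
separate_graph = tf.Graph()
with separate_graph.as_default():
    other = tf.Variable(0.0, name="w")

# Compare the names TensorFlow actually assigned.
print(first.name, second.name, other.name)
```

Keeping each model in its own graph avoids the clash, so the names stored in each checkpoint still line up with the names in the graph being restored.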
First I created two graphs:

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

Then two sessions, one per graph:

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

Then I initialised the variables in each session and restored each graph separately:

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

From here, whenever a session was needed, I wrapped any tf functions that belonged to it in a with sess.as_default(): block (or with adv_sess.as_default(): for the adversary). At the end I manually closed the sessions:

sess.close()
adv_sess.close()
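
Putting the steps together, here is a minimal sketch of the whole pattern. It is not the original code: the two toy graphs are hypothetical stand-ins for Model and Adversary, checkpoint restoring is left out, and it assumes the TF 1.x API is reachable through tensorflow.compat.v1. The try/finally is an addition so that both sessions are closed even if an evaluation fails.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

# Toy stand-in for the model graph (hypothetical, not the original Model class).
model_graph = tf.Graph()
with model_graph.as_default():
    model_x = tf.placeholder(tf.float32, shape=[], name="x")
    model_w = tf.Variable(2.0, name="w")
    model_out = model_x * model_w
    model_init = tf.global_variables_initializer()

# Toy stand-in for the adversary graph; note it can reuse the name "w".
adv_graph = tf.Graph()
with adv_graph.as_default():
    adv_x = tf.placeholder(tf.float32, shape=[], name="x")
    adv_w = tf.Variable(-1.0, name="w")
    adv_out = adv_x * adv_w
    adv_init = tf.global_variables_initializer()

# One session per graph.
sess = tf.Session(graph=model_graph)
adv_sess = tf.Session(graph=adv_graph)

try:
    sess.run(model_init)
    adv_sess.run(adv_init)

    # Tensor.eval() uses whichever session is currently the default.
    with sess.as_default():
        model_result = model_out.eval(feed_dict={model_x: 3.0})

    with adv_sess.as_default():
        adv_result = adv_out.eval(feed_dict={adv_x: 3.0})

    print(model_result, adv_result)
finally:
    # Sessions are not closed automatically when opened this way.
    sess.close()
    adv_sess.close()
```

Evaluating a tensor under the wrong default session raises an error because the session's graph does not match the tensor's graph, which makes mix-ups between the two models easy to spot.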