
python - Restore subset of variables in Tensorflow

I am training a Generative Adversarial Network (GAN) in TensorFlow, where there are basically two different networks, each with its own optimizer.

self.G, self.layer = self.generator(self.inputCT, batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(
    self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(
    self.d_loss, var_list=self.d_vars)

The problem is that I train one of the networks (g) first, and then I want to train g and d together. However, when I call the load function:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        return True
    else:
        return False

I get an error like this (with a lot more traceback):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

I can restore the g network and keep training with that function, but when I want to start d from scratch and g from the stored model, I get that error.



1 Reply


To restore a subset of variables, you must create a new tf.train.Saver and pass it a specific list of variables to restore in the optional var_list argument.
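In its simplest form, that looks something like this (a minimal sketch; g_vars is a placeholder for whatever list of variables you build, and the checkpoint path is the one from the question):

g_saver = tf.train.Saver(var_list=g_vars)  # this Saver only knows about g_vars
g_saver.restore(sess, "checkpoint/MR2CT.model-96000")

The rest of this answer explains how to build that var_list.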

By default, a tf.train.Saver will create ops that (i) save every variable in your graph when you call saver.save() and (ii) look up (by name) every variable in the given checkpoint when you call saver.restore(). While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:

  1. If you only want to restore a subset of the variables, you can get a list of these variables by calling tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX), assuming that you put the "g" network in a common with tf.name_scope(G_NETWORK_PREFIX): or tf.variable_scope(G_NETWORK_PREFIX): block. You can then pass this list to the tf.train.Saver constructor.

  2. If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph). Both cases are shown in the sketch after this list.
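Putting this together for the GAN in the question, a minimal sketch might look like the following. G_NETWORK_PREFIX is a placeholder for whatever scope actually wraps your g network, and the "new_prefix/" rename in option 2 is purely hypothetical:

import tensorflow as tf

G_NETWORK_PREFIX = "generator"  # placeholder: the scope that wraps the g network

# Option 1: collect only the g-network variables and build a Saver for them.
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)
g_saver = tf.train.Saver(var_list=g_vars)

with tf.Session() as sess:
    # Initialize everything first (the d network, Adam slot variables such as
    # beta1_power/beta2_power, ...), then overwrite just g from the checkpoint.
    sess.run(tf.initialize_all_variables())
    g_saver.restore(sess, "checkpoint/MR2CT.model-96000")

# Option 2: if the names differ between the checkpoint and this graph, map
# checkpoint keys (variable names without the ":0" suffix) to the tf.Variable
# objects in the current graph.
g_saver = tf.train.Saver(
    var_list={v.op.name.replace("new_prefix/", ""): v for v in g_vars})

Because the d variables and the Adam slot variables are initialized in the session rather than looked up in the checkpoint, this avoids the "beta2_power" not found error from the question.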

