Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

tensorflow - Is it possible to make a trainable variable not trainable?

I created a trainable variable in a scope. Later, I entered the same scope, called reuse_variables() on it, and used get_variable to retrieve the same variable. However, I cannot set the variable's trainable property to False this way. My get_variable line looks like:

weight_var = tf.get_variable('weights', trainable=False)

But the variable 'weights' still appears in the output of tf.trainable_variables().

Can I set a shared variable's trainable flag to False by using get_variable?

The reason I want to do this is that I'm trying to reuse the low-level filters pre-trained in the VGG net in my model. I want to build the graph as before, retrieve the weights variable, assign the VGG filter values to it, and then keep those values fixed during the subsequent training steps.


1 Reply


After looking at the documentation and the code, I was not able to find a way to remove a Variable from the TRAINABLE_VARIABLES collection.

Here is what happens:

  • The first time tf.get_variable('weights', trainable=True) is called, the variable is created and added to the TRAINABLE_VARIABLES collection.
  • The second time you call tf.get_variable('weights', trainable=False), you get back the same variable, but the argument trainable=False has no effect because the variable is already in the TRAINABLE_VARIABLES collection (and get_variable offers no way to remove it from there).
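A minimal sketch reproducing this behavior, assuming TensorFlow 2.x with the 1.x graph-mode API accessed through tf.compat.v1 (on TensorFlow 1.x a plain `import tensorflow as tf` should behave the same). The scope name conv1 and the shape are hypothetical.

```python
# Assumption: TensorFlow 2.x, using the 1.x graph-mode API via tf.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

with tf.variable_scope('conv1') as scope:
    # First call creates the variable and adds it to TRAINABLE_VARIABLES.
    w1 = tf.get_variable('weights', shape=[3, 3, 3, 64], trainable=True)
    scope.reuse_variables()
    # Second call returns the existing variable; trainable=False is ignored.
    w2 = tf.get_variable('weights', trainable=False)

trainable_names = [v.name for v in tf.trainable_variables()]
print(w1.name == w2.name)  # True: both refer to conv1/weights
print(trainable_names)     # ['conv1/weights:0']: still listed as trainable
```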

First solution

When calling the optimizer's minimize method (see the documentation), you can pass var_list=[...] containing only the variables you want to optimize.

For instance, if you want to freeze all the layers of VGG except the last two, you can pass the weights of the last two layers in var_list.
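A minimal sketch of the var_list approach, assuming TensorFlow 2.x with the 1.x graph-mode API via tf.compat.v1. The two-layer model, the scope names frozen and head, and the learning rate are hypothetical stand-ins for a pre-trained VGG layer and a new layer on top of it.

```python
# Assumption: TensorFlow 2.x, using the 1.x graph-mode API via tf.compat.v1.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Hypothetical tiny model: 'frozen' stands in for a pre-trained VGG layer,
# 'head' for a new layer that should keep training.
with tf.variable_scope('frozen'):
    frozen_w = tf.get_variable('weights', initializer=tf.ones([2, 2]))
with tf.variable_scope('head'):
    head_w = tf.get_variable('weights', initializer=tf.ones([2, 1]))

x = tf.constant([[1.0, 2.0]])
y = tf.matmul(tf.matmul(x, frozen_w), head_w)
loss = tf.reduce_sum(tf.square(y - 1.0))

# Both variables are in TRAINABLE_VARIABLES, but var_list restricts the
# update to head_w, so frozen_w keeps its values.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, var_list=[head_w])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    frozen_after, head_after = sess.run([frozen_w, head_w])

print(np.allclose(frozen_after, 1.0))  # True: frozen layer unchanged
print(np.allclose(head_after, 1.0))    # False: head layer was updated
```

Since var_list only controls which gradients are applied, the frozen variable still shows up in tf.trainable_variables(); it simply never receives an update from this train_op.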

Second solution

You can use a tf.train.Saver() to save variables and restore them later (see this tutorial).

  • First, train your entire VGG model with all variables trainable, and save them to a checkpoint file by calling saver.save(sess, "/path/to/dir/model.ckpt").
  • Then (in another script), build the second version of the model with the variables created as non-trainable, and load the previously stored values with saver.restore(sess, "/path/to/dir/model.ckpt") before training.
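A minimal sketch of this two-phase workflow in a single script, assuming TensorFlow 2.x with the 1.x graph-mode API via tf.compat.v1. Two separate tf.Graph objects stand in for the two files; the temporary checkpoint path, the scope name conv1, and the fill value are hypothetical.

```python
# Assumption: TensorFlow 2.x, using the 1.x graph-mode API via tf.compat.v1.
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Phase 1 (first script): variable is trainable; train, then save.
with tf.Graph().as_default():
    with tf.variable_scope('conv1'):
        w = tf.get_variable('weights', initializer=tf.fill([3], 2.0))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... training steps would go here ...
        saver.save(sess, ckpt_path)

# Phase 2 (second script): same name, now created with trainable=False.
with tf.Graph().as_default():
    with tf.variable_scope('conv1'):
        w = tf.get_variable('weights', shape=[3], trainable=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restoring assigns the saved values, so no initializer run is needed.
        saver.restore(sess, ckpt_path)
        restored = sess.run(w)
        trainable_names = [v.name for v in tf.trainable_variables()]

print(restored)         # the values saved in phase 1
print(trainable_names)  # []: conv1/weights is no longer trainable
```

The checkpoint matches variables by name, which is why phase 2 recreates conv1/weights under the same scope even though its trainable flag differs.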

Optionally, you can decide to save only some of the variables in your checkpoint file. See the documentation for more information.
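A minimal sketch of saving a subset, assuming TensorFlow 2.x with the 1.x graph-mode API via tf.compat.v1. Passing var_list to tf.train.Saver limits what is written; the scope names conv1 and fc and the temporary path are hypothetical, with conv1 standing in for the pre-trained VGG filters.

```python
# Assumption: TensorFlow 2.x, using the 1.x graph-mode API via tf.compat.v1.
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'conv1.ckpt')

with tf.Graph().as_default():
    with tf.variable_scope('conv1'):
        conv1_w = tf.get_variable('weights', initializer=tf.fill([3], 5.0))
    with tf.variable_scope('fc'):
        fc_w = tf.get_variable('weights', initializer=tf.zeros([3]))
    # Only the variables in var_list are written to the checkpoint.
    subset_saver = tf.train.Saver(var_list=[conv1_w])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        subset_saver.save(sess, ckpt_path)

# Inspect the checkpoint contents: list of (name, shape) pairs.
saved_names = [name for name, _ in tf.train.list_variables(ckpt_path)]
print(saved_names)  # ['conv1/weights']
```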

