Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
260 views
in Technique[技术] by (71.8m points)

python - How to set adaptive learning rate for GradientDescentOptimizer?

I am using TensorFlow to train a neural network. This is how I am initializing the GradientDescentOptimizer:

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

mse        = tf.reduce_mean(tf.square(out - out_))
train_step = tf.train.GradientDescentOptimizer(0.3).minimize(mse)

The thing here is that I don't know how to set an update rule for the learning rate or a decay value for that.

How can I use an adaptive learning rate here?

See Question&Answers more detail:os

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)

First of all, tf.train.GradientDescentOptimizer is designed to use a constant learning rate for all variables in all steps. TensorFlow also provides out-of-the-box adaptive optimizers including the tf.train.AdagradOptimizer and the tf.train.AdamOptimizer, and these can be used as drop-in replacements.

However, if you want to control the learning rate with otherwise-vanilla gradient descent, you can take advantage of the fact that the learning_rate argument to the tf.train.GradientDescentOptimizer constructor can be a Tensor object. This allows you to compute a different value for the learning rate in each step, for example:

learning_rate = tf.placeholder(tf.float32, shape=[])
# ...
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(mse)

sess = tf.Session()

# Feed different values for learning rate to each training step.
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.01})
sess.run(train_step, feed_dict={learning_rate: 0.01})

Alternatively, you could create a scalar tf.Variable that holds the learning rate, and assign it each time you want to change the learning rate.


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...