Both current answers kind of work by filtering the variable names on the 'Momentum' string. But that is very brittle, for two reasons:
- It could silently (re-)initialize some other variables you don't actually want to reset! Either simply because of a name-clash, or because you have a more complex graph and optimize different parts separately, for example.
- It will only work for one specific optimizer; how do you know which names to look out for with other optimizers?
- Bonus: an update to TensorFlow might silently break your code.
Fortunately, TensorFlow's abstract Optimizer class has a mechanism for exactly this: these extra optimizer variables are called "slots", and you can get all slot names of an optimizer using the get_slot_names() method:
opt = tf.train.MomentumOptimizer(...)
print(opt.get_slot_names())
# prints ['momentum']
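As a quick sanity check, the same call tells you the names for other optimizers too. The following is an illustrative sketch (the slot names shown are what TF 1.x's AdamOptimizer typically uses and may differ between versions); note that the slots only exist once the update op has been built:
adam = tf.train.AdamOptimizer(0.001)
adam_step = adam.minimize(loss)   # building the update op creates the slots
print(adam.get_slot_names())
# e.g. ['m', 'v']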
And you can get the variable corresponding to a slot for a specific (trainable) variable v using the get_slot(var, slot_name) method:
opt.get_slot(some_var, 'momentum')
Putting all this together, you can create an op that resets (re-initializes) the optimizer's state as follows:
# list of vars to optimize, e.g. all trainable variables
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
opt = tf.train.MomentumOptimizer(0.1, 0.95)
step_op = opt.minimize(loss, var_list=var_list)
reset_opt_op = tf.variables_initializer(
    [opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list])
This really only resets the correct variables, and is robust across optimizers.
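For example, here is a minimal self-contained sketch of how you would run reset_opt_op in a standard tf.Session training loop (the toy variable, loss, and the step at which to reset are all made up for illustration):
import tensorflow as tf

w = tf.get_variable('w', initializer=5.0)   # toy variable and loss, just for illustration
loss = tf.square(w)
var_list = [w]

opt = tf.train.MomentumOptimizer(0.1, 0.95)
step_op = opt.minimize(loss, var_list=var_list)
reset_opt_op = tf.variables_initializer(
    [opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(20):
        sess.run(step_op)
        if step == 9:                        # reset the accumulated momentum halfway through
            sess.run(reset_opt_op)
    print(sess.run(opt.get_slot(w, 'momentum')))   # momentum slot value after training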
Except for one unfortunate caveat: AdamOptimizer. That one also keeps a counter for how often it's been called. That means you should really think hard about what you're doing here anyway, but for completeness' sake, you can get its extra state via opt._get_beta_accumulators(). The returned list should be added to the list in the reset_opt_op line above.
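For completeness, here is a hedged sketch of what that would look like, reusing loss and var_list from above (note that _get_beta_accumulators() is a private method, so it may change between TensorFlow versions):
opt = tf.train.AdamOptimizer(0.001)
step_op = opt.minimize(loss, var_list=var_list)

# Slot variables ('m' and 'v') plus Adam's internal beta-power accumulators.
adam_state = [opt.get_slot(var, name)
              for name in opt.get_slot_names()
              for var in var_list]
adam_state.extend(opt._get_beta_accumulators())   # private API; may change across versions
reset_opt_op = tf.variables_initializer(adam_state)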