Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Using op inputs when defining custom gradients in TensorFlow

I'm trying to define a gradient method for my custom TF operation. Most of the solutions I have found online seem to be based on a gist by harpone. I'm reluctant to use that approach because it relies on py_func, which won't run on the GPU. I found another solution here that uses tf.identity(); it looks more elegant and I think it will run on the GPU. However, I'm having trouble accessing the inputs of the op in my custom gradient function. Here's my code:

import tensorflow as tf  # TF 1.x graph-mode API


@tf.RegisterGradient('MyCustomGradient')
def _custom_gradient(op, gradients):
    # Return the first input of the op this gradient is attached to.
    x = op.inputs[0]
    return x


def my_op(w):
    return tf.pow(w, 3)


var_foo = tf.Variable(5, dtype=tf.float32)
bar = my_op(var_foo)

g = tf.get_default_graph()
# Use MyCustomGradient for every Identity op created inside this block.
with g.gradient_override_map({'Identity': 'MyCustomGradient'}):
    bar = tf.identity(bar)
    grads = tf.gradients(bar, var_foo)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads))

I was expecting _custom_gradient() to return the input to the op (5 in this example), but instead it seems to return the op's output multiplied by the gradient. My custom my_op will contain non-differentiable operations such as tf.sign, and I'd like to define my custom gradient based on the inputs. What am I doing wrong?


1 Reply


There is no problem with your code:

Let's first do the forward pass:

var_foo = 5 -> bar = 125 -> tf.identity(bar) = 125

Now let's backpropagate:

The gradient of tf.identity(bar) with respect to its argument bar is, by your definition, bar itself, that is, 125. The gradient of bar with respect to var_foo is 3 times the square of var_foo, which is 75. Multiply the two and you get 9375, which is indeed the output of your code.
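The arithmetic can be checked with a few lines of plain Python. This is only a sketch of the chain rule for this one example, not TensorFlow code; the helper names my_op_grad and identity_custom_grad are made up for illustration.

```python
# Plain-Python walk-through of the numbers above (no TensorFlow involved).

def my_op(w):
    # Forward function from the question: w cubed.
    return w ** 3

def my_op_grad(w):
    # Analytic derivative of w**3 with respect to w.
    return 3 * w ** 2

def identity_custom_grad(op_input, upstream):
    # Hypothetical stand-in for _custom_gradient in the question: it ignores
    # the upstream gradient and returns the identity op's own input.
    return op_input

var_foo = 5.0
bar = my_op(var_foo)   # forward pass: 125.0
out = bar              # identity leaves the value unchanged

grad_at_identity = identity_custom_grad(bar, upstream=1.0)   # 125.0
grad_wrt_var_foo = grad_at_identity * my_op_grad(var_foo)    # 125.0 * 75.0

print(grad_wrt_var_foo)  # 9375.0
```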

op.inputs[0] holds the forward-pass value of the first input to the op whose gradient you overrode. Here that op is Identity, and its input is bar, so op.inputs[0] is 125, not var_foo (5).
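To make the point about op.inputs[0] concrete, here is a toy model in plain Python. ToyOp is a hypothetical stand-in and not TensorFlow's Operation class; it only illustrates that a gradient function sees the inputs of whichever op it is attached to.

```python
# Toy stand-in for a graph op; this is NOT TensorFlow's Operation class.
class ToyOp:
    def __init__(self, name, inputs, output):
        self.name = name
        self.inputs = inputs    # values fed into this op
        self.output = output    # value this op produced

def custom_gradient(op, upstream):
    # Same body as _custom_gradient in the question.
    return op.inputs[0]

var_foo = 5.0
pow_op = ToyOp("Pow", inputs=[var_foo, 3.0], output=var_foo ** 3)
identity_op = ToyOp("Identity", inputs=[pow_op.output], output=pow_op.output)

# Attached to Identity, the function sees bar (125.0), not var_foo.
seen_on_identity = custom_gradient(identity_op, upstream=1.0)
# Attached to an op that consumes var_foo directly, it would see var_foo.
seen_on_pow = custom_gradient(pow_op, upstream=1.0)

print(seen_on_identity, seen_on_pow)  # 125.0 5.0
```

So to base a gradient on var_foo, the override would have to be attached to an op that takes var_foo as its input. TensorFlow also provides the tf.custom_gradient decorator, which may be a more direct way to attach a gradient to a whole function; it is not used in this sketch.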

