Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Tensorflow tf.function conditionals

It took me some time to pin down the problem. Here it is:

import tensorflow as tf

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
    else:
        s.fun(x)
    return s.result

no_fun(tf.constant(1.), ...)
# -> <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

I would really expect to get 2.0 back instead of 1.0. I figured out that the reason is that both branches of the conditional are traced, and because I return the value through a side effect on s, only the result of the second branch survives. The question is: how do I code around this limitation? Using return values would solve it, but it would definitely uglify the code, because ComplicatedStuff wraps a bunch of intermediate results that I don't want to expose like that. Is there a better option?
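To see why only the second branch "survives", here is a toy model in plain Python. No TensorFlow is involved, and `trace_cond`, `no_fun_side_effect` and `no_fun_returned` are hypothetical names made up for this illustration. It assumes that, as with tf.cond, both branches are run once while the graph is built and the choice between them is made later: a side effect performed during that build step is not "selected", whereas a returned value is.

```python
# Toy model in plain Python (NOT TensorFlow). `trace_cond` is a hypothetical
# stand-in for tracing a conditional: both branches run once while the
# "graph" is built, and the choice between their outputs is deferred.


class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val


def trace_cond(true_fn, false_fn):
    """Run both branches now; return a selector for their return values."""
    out_true = true_fn()    # "traced" first
    out_false = false_fn()  # "traced" second, so its side effects win
    return lambda pred: out_true if pred else out_false


def no_fun_side_effect(x):
    s = ComplicatedStuff()
    # Both branches mutate s while being "traced"; nothing is returned.
    trace_cond(lambda: s.fun(2 * x), lambda: s.fun(x))
    return s.result


def no_fun_returned(x):
    s = ComplicatedStuff()
    # Branches only return values; the side effect happens after selection.
    select = trace_cond(lambda: 2 * x, lambda: x)
    s.fun(select(x > 0.5))
    return s.result


print(no_fun_side_effect(1.0))  # 1.0 (the else branch ran last)
print(no_fun_returned(1.0))     # 2.0
print(no_fun_returned(0.23))    # 0.23
```

The model reproduces the surprising 1.0 and shows that moving the side effect after the selection restores the expected result.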

The workaround I came up with, which more or less preserves the structure, is this hack:

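import tensorflow as tf

# Subclassing dict makes s a nested structure whose values can be
# passed through the conditional.
class ComplicatedStuff(dict):
    def __init__(self):
        super().__init__()
        self.result = None

    def fun(self, val):
        self.result = val

    # Store every attribute assignment as a dict item.
    def __setattr__(self, item, value):
        self[item] = value

    # Dunder names and names not stored in the dict use normal lookup.
    def __getattribute__(self, item):
        if item.startswith("__") or item not in self:
            return super().__getattribute__(item)
        else:
            return self[item]

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
        s = s  # no-op reassignment: AutoGraph now treats s as modified here
    else:
        s.fun(x)
        s = s
    return s.result

no_fun(tf.constant(1.), ...)
# -> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>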

There must be a better option, right?

Question from: https://stackoverflow.com/questions/66062501/tensorflow-tf-function-conditionals

Answer:

TensorFlow automatically converts if statements whose condition is a tensor into tf.cond ops when they appear inside a tf.function. This conversion is done by AutoGraph.
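A small experiment makes the conversion visible. Python code inside the branches only runs while tf.function traces, so appending to a Python list records which branch bodies were traced. `pick`, `python_flag` and `traced` are names made up for this sketch, and it assumes TF 2.x with AutoGraph enabled (the tf.function default).

```python
import tensorflow as tf

# Hypothetical log: Python statements in the branches run only during
# tracing, so this list records which branch bodies were traced.
traced = []


@tf.function
def pick(x, python_flag):
    # `python_flag` is a plain Python bool, so this stays an ordinary
    # Python if and only one branch is traced.
    if python_flag:
        traced.append("python: true")
    else:
        traced.append("python: false")

    # `x > 0.5` is a tensor, so AutoGraph turns this if into a tf.cond,
    # which traces both branches; `y` becomes the output of the cond.
    if x > 0.5:
        traced.append("tensor: true")
        y = 2 * x
    else:
        traced.append("tensor: false")
        y = x
    return y


result = pick(tf.constant(1.0), True)
print(float(result))   # 2.0
print(sorted(traced))  # both "tensor" branches, only one "python" branch
```

The returned value is still selected correctly at run time; only the Python side effects (here, the list appends) happen for both branches.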

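Here the conversion does happen, but s.fun(...) only mutates a Python object while the branches are traced. That mutation is not an output of the tf.cond, so s.result keeps whatever the last traced branch assigned. We can instead call tf.cond ourselves, let the branches return the value, and apply the side effect afterwards: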

import tensorflow as tf

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)

    # tf.cond returns the selected value; the side effect on s is applied
    # once, outside the conditional.
    x_tmp = tf.cond(x > .5, lambda: 2*x, lambda: x)
    s.fun(x_tmp)
    return s.result

no_fun(tf.constant(1.), ...)
# -> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

no_fun(tf.constant(0.23), ...)
# -> <tf.Tensor: shape=(), dtype=float32, numpy=0.23>

As mentioned in the docs:

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time.
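If you prefer to keep Python's if syntax, the same idea should work as long as each branch assigns a local variable instead of mutating s: AutoGraph then makes that variable an output of the tf.cond it generates, and s.fun is called once afterwards. A minimal sketch, assuming TF 2.x; the blabla argument is dropped for brevity, and `y` is a name chosen for illustration.

```python
import tensorflow as tf


class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val


@tf.function
def no_fun(x):
    s = ComplicatedStuff()
    # Each branch assigns the local variable `y` instead of mutating `s`.
    # AutoGraph sees `y` assigned in both branches and makes it an output
    # of the generated tf.cond.
    if x > .5:
        y = 2 * x
    else:
        y = x
    # The side effect on `s` happens once, after the branch is selected.
    s.fun(y)
    return s.result


print(no_fun(tf.constant(1.)).numpy())    # 2.0
print(no_fun(tf.constant(0.23)).numpy())  # 0.23
```

Only the value handed to s.fun has to cross the conditional, so the other intermediate results held by ComplicatedStuff need not be returned from the branches.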


...