Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - How to define a map_func for tf.data.Dataset.map that can return empty result

I'm using the Dataset APIs in TF 2.4. Currently I have a working piece of code like this:

import typing
import numpy as np
import tensorflow as tf

def map_func(a: int, b: int) -> typing.Tuple[np.ndarray, np.ndarray]:
    # some complex logic here, e.g. protobuf message deserialization; dtypes must match Tout
    value = 0 if some_condition() else 1
    return np.array([value], dtype=np.float32), np.array([value], dtype=np.int32)

dataset = (
    some_dataset
    .map(lambda a, b: tf.numpy_function(map_func, inp=[a, b], Tout=(tf.float32, tf.int32)))
    # filter out samples whose labels are all zeros, regardless of the features
    .filter(lambda features, labels: tf.reduce_any(tf.not_equal(labels, 0)))
    .some_other_apis()
)

The map_func function defined above returns a tuple of (features, labels), where labels may contain zeros or non-zeros. By chaining a filter call, I drop the samples whose labels are all zeros.

What's the problem
I'm wondering whether it is possible to integrate the filter logic into map_func, because the current implementation looks somewhat ugly and redundant. I tried returning a tuple of ([], []) or (None, None) when I want to discard a result, but TF complains that the return types do not match.

question from:https://stackoverflow.com/questions/65842496/how-to-define-a-map-func-for-tf-data-dataset-map-that-can-return-empty-result


1 Reply


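You can use tf.where and tf.gather_nd inside map_func to keep only the values that pass a condition, so each element's result can have a different length: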

import tensorflow as tf
import numpy as np

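def map_func(a):
    # tf.where returns the indices where the condition holds;
    # tf.gather_nd picks the values at those indices
    return tf.gather_nd(a, tf.where(a > 0.5))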

inputs = np.random.rand(10, 5)

np.round(inputs, 3)
array([[0.952, 0.329, 0.786, 0.714, 0.819],
       [0.048, 0.98 , 0.363, 0.03 , 0.078],
       [0.779, 0.833, 0.368, 0.216, 0.669],
       [0.807, 0.332, 0.217, 0.594, 0.254],
       [0.787, 0.453, 0.943, 0.915, 0.76 ],
       [0.047, 0.014, 0.555, 0.57 , 0.422],
       [0.195, 0.167, 0.077, 0.562, 0.586],
       [0.693, 0.434, 0.055, 0.213, 0.021],
       [0.459, 0.34 , 0.785, 0.938, 0.979],
       [0.08 , 0.667, 0.781, 0.092, 0.644]])
ds = tf.data.Dataset.from_tensor_slices(inputs)

ds = ds.map(map_func) 

for i in ds:
    print(np.round(i.numpy(), 3))
[0.952 0.786 0.714 0.819]
[0.98]
[0.779 0.833 0.669]
[0.807 0.594]
[0.787 0.943 0.915 0.76 ]
[0.555 0.57 ]
[0.562 0.586]
[0.693]
[0.785 0.938 0.979]
[0.667 0.781 0.644]
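
Note that the example above trims values inside each element rather than dropping whole elements. If the goal is to let map_func itself decide whether a sample survives, one possible approach (a sketch, not necessarily what the answer intended) is to have it return a batch of zero or one rows and flatten the result with flat_map. The keep/drop rule, the to_dataset helper, the toy input dataset, and the [None, 1] shapes below are all assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf


def map_func(a, b):
    # Hypothetical stand-in for the real logic: keep a sample only when a + b is odd.
    if (a + b) % 2 == 0:
        # Zero rows: nothing is emitted for this sample after flattening.
        return (np.zeros((0, 1), dtype=np.float32),
                np.zeros((0, 1), dtype=np.int32))
    # One row: exactly one (features, labels) pair is emitted.
    return (np.array([[float(a)]], dtype=np.float32),
            np.array([[1]], dtype=np.int32))


def to_dataset(a, b):
    features, labels = tf.numpy_function(
        map_func, inp=[a, b], Tout=(tf.float32, tf.int32))
    # tf.numpy_function drops static shape information, so restore it here
    # (assumed shape: any number of rows, one column).
    features.set_shape([None, 1])
    labels.set_shape([None, 1])
    # Slicing along the first axis yields zero or one elements per input.
    return tf.data.Dataset.from_tensor_slices((features, labels))


# Toy input: a = 0..5, b = 1 for every sample.
some_dataset = tf.data.Dataset.from_tensor_slices(
    (np.arange(6, dtype=np.int64), np.ones(6, dtype=np.int64)))

result = some_dataset.flat_map(to_dataset)
kept = [(f.numpy().tolist(), l.numpy().tolist()) for f, l in result]
print(kept)
```

With this layout the separate filter call is no longer needed, because an empty first axis simply contributes no elements. The trade-off is that every output needs a leading batch axis and its shape has to be declared by hand.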
