I'm trying to convert the Iris tutorial (https://www.tensorflow.org/get_started/estimator) to read training data from .png files instead of .csv. It works using `numpy_input_fn` but not when I build the input from a `Dataset`. I think `input_fn()` is returning the wrong type, but I don't really understand what it should return or how to produce that. The error is:
```
  File "iris_minimal.py", line 27, in <module>
    model_fn().train(input_fn(), steps=1)
...
    raise TypeError('unsupported callable') from ex
TypeError: unsupported callable
```
TensorFlow version is 1.3. Complete code:
```python
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

NUM_CLASSES = 3

def model_fn():
    feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]
    return tf.estimator.DNNClassifier([10, 20, 10], feature_columns, "tmp/iris_model", NUM_CLASSES)

def input_parser(img_path, label):
    one_hot = tf.one_hot(label, NUM_CLASSES)
    file_contents = tf.read_file(img_path)
    image_decoded = tf.image.decode_png(file_contents, channels=1)
    image_decoded = tf.image.resize_images(image_decoded, [2, 2])
    image_decoded = tf.reshape(image_decoded, [4])
    return image_decoded, one_hot

def input_fn():
    filenames = tf.constant(['images/image_1.png', 'images/image_2.png'])
    labels = tf.constant([0, 1])
    data = Dataset.from_tensor_slices((filenames, labels))
    data = data.map(input_parser)
    iterator = data.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

model_fn().train(input_fn(), steps=1)
```
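The last frame of the traceback appears to be Python's own `inspect.getfullargspec`, which raises `TypeError('unsupported callable')` whenever it is handed something that is not a function. A minimal reproduction without TensorFlow, assuming (as the traceback suggests) that `train()` inspects the signature of whatever it receives as `input_fn`. Here `arg_names` is a hypothetical stand-in for that check, not TensorFlow's API, and `input_fn` is a dummy that only mimics the shape of a `(features, labels)` return value.

```python
import inspect

def input_fn():
    # Dummy stand-in for the real input_fn: returns a (features, labels) tuple
    return {"x": [1.0, 2.0, 3.0, 4.0]}, 0

def arg_names(fn):
    # Hypothetical stand-in for the signature check done on input_fn:
    # inspect.getfullargspec only accepts callables
    return inspect.getfullargspec(fn).args

print(arg_names(input_fn))      # the function object itself: prints []

try:
    arg_names(input_fn())       # the function's return value, a plain tuple
except TypeError as err:
    print(err)                  # prints: unsupported callable
```

Only the first call succeeds: the function object can be inspected, while the tuple it returns cannot, which produces the same message as in the traceback above.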