python - Passing a numpy array to a tensorflow Queue

I have a NumPy array and would like to read it in TensorFlow code using a queue. I would like the queue to return the whole dataset shuffled, for a specified number of epochs, and to throw an error after that. Ideally I would not need to hard-code the size of an example or the number of examples. I think shuffle_batch is meant to serve that purpose. I have tried using it as follows:

import tensorflow as tf

data = tf.constant(train_np)  # train_np is my numpy array of shape (num_examples, example_size)
batch = tf.train.shuffle_batch([data], batch_size=5, capacity=52200,
                               min_after_dequeue=10, num_threads=1,
                               seed=None, enqueue_many=True)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
tf.train.start_queue_runners(sess=sess)
batch.eval()

The problem with that approach is that it reads the data continuously and I cannot tell it to stop after some number of epochs. I am aware I could use a RandomShuffleQueue and insert the data into it a few times, but (a) I don't want to use num_epochs times the dataset's worth of memory, and (b) it would allow the queue to shuffle across epoch boundaries.

Is there a nice way to read the shuffled data in epochs in TensorFlow without writing my own queue?
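
For clarity, this is roughly the consumption pattern I am after (a sketch; my assumption is that the input pipeline raises tf.errors.OutOfRangeError once the requested number of epochs has been read):

try:
    while True:
        print(batch.eval())  # shuffled batches, at most the requested number of epochs
except tf.errors.OutOfRangeError:
    print("Done")  # all epochs consumed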

1 Reply

You could create another queue, enqueue your data onto it num_epochs times, close it, and then hook it up to your batch. To save memory, you can make this queue small and enqueue items onto it in parallel. There will be a bit of mixing between epochs. To fully prevent mixing, you could take the code below with num_epochs=1 and call it num_epochs times.

import threading

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
data = np.array([1, 2, 3, 4])
num_epochs = 5
queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])

def create_session():
    config = tf.ConfigProto()
    # Time out stuck ops (e.g. a blocked enqueue) after 20 seconds.
    config.operation_timeout_in_ms = 20000
    return tf.InteractiveSession(config=config)

enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()
batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)

sess = create_session()

def fill_queue():
    # Enqueue the whole dataset num_epochs times, then close the queue so
    # downstream ops see OutOfRangeError once the data runs out.
    for i in range(num_epochs):
        sess.run(enqueue_op, feed_dict={queue1_input: data})
    sess.run(close_op)

fill_thread = threading.Thread(target=fill_queue, args=())
fill_thread.start()

# Read the shuffled data from the queue.
tf.train.start_queue_runners()
try:
    while True:
        print(batch.eval())
except tf.errors.OutOfRangeError:
    print("Done")

By the way, the enqueue_many pattern above will hang when the queue is not large enough to hold the entire NumPy dataset. You can give yourself the flexibility of a smaller queue by loading the data in chunks, as below.

tf.reset_default_graph()
data = np.array([1, 2, 3, 4])
queue1_capacity = 2
num_epochs = 2
queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])

enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()

def dequeue():
    try:
        while True:
            print(sess.run(dequeue_op))
    except tf.errors.OutOfRangeError:
        # Raised once the queue has been closed and emptied.
        return

def enqueue():
    # Feed the data in chunks of queue1_capacity so the queue never needs
    # to hold the whole dataset at once.
    for i in range(num_epochs):
        start_pos = 0
        while start_pos < len(data):
            end_pos = start_pos + queue1_capacity
            data_chunk = data[start_pos:end_pos]
            sess.run(enqueue_op, feed_dict={queue1_input: data_chunk})
            start_pos += queue1_capacity
    sess.run(close_op)

sess = create_session()

enqueue_thread = threading.Thread(target=enqueue, args=())
enqueue_thread.start()

dequeue_thread = threading.Thread(target=dequeue, args=())
dequeue_thread.start()
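
If you also want shuffled batches out of this chunked version, you can drop the dequeue thread and hook dequeue_op into tf.train.shuffle_batch exactly as in the first snippet. A rough sketch, reusing the toy parameters from above:

# Sketch: shuffled batches from the chunked queue (use instead of the dequeue thread).
batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)
tf.train.start_queue_runners()
try:
    while True:
        print(batch.eval())
except tf.errors.OutOfRangeError:
    print("Done")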
