You could create another queue, enqueue your data onto it num_epochs times, close it, and then hook it up to your batch op. To save memory, you can make this queue small and enqueue items onto it in parallel. There will be a bit of mixing between epochs. To fully prevent mixing, you could take the code below with num_epochs=1 and call it num_epochs times (sketched after the first snippet).
import threading
import numpy as np
import tensorflow as tf

tf.reset_default_graph()

data = np.array([1, 2, 3, 4])
num_epochs = 5

queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])

def create_session():
    config = tf.ConfigProto()
    config.operation_timeout_in_ms = 20000  # fail with an error instead of hanging forever
    return tf.InteractiveSession(config=config)

enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()
batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5,
                               min_after_dequeue=4)

sess = create_session()

def fill_queue():
    # enqueue the whole dataset once per epoch, then close the queue
    for i in range(num_epochs):
        sess.run(enqueue_op, feed_dict={queue1_input: data})
    sess.run(close_op)

fill_thread = threading.Thread(target=fill_queue, args=())
fill_thread.start()

# read the data from the queue, shuffled
tf.train.start_queue_runners()
try:
    while True:
        print(batch.eval())
except tf.errors.OutOfRangeError:
    print("Done")
By the way, the enqueue_many pattern above will hang if the queue is not large enough to hold the entire numpy dataset at once. You can give yourself the flexibility of a smaller queue by loading the data in chunks, as below.
tf.reset_default_graph()

data = np.array([1, 2, 3, 4])
queue1_capacity = 2
num_epochs = 2

queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])
enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()

def dequeue():
    # drain the queue until it is closed and empty
    try:
        while True:
            print(sess.run(dequeue_op))
    except tf.errors.OutOfRangeError:
        return

def enqueue():
    # feed the dataset in queue-sized chunks so a single enqueue never
    # needs more capacity than the queue has
    for i in range(num_epochs):
        start_pos = 0
        while start_pos < len(data):
            end_pos = start_pos + queue1_capacity
            data_chunk = data[start_pos:end_pos]
            sess.run(enqueue_op, feed_dict={queue1_input: data_chunk})
            start_pos += queue1_capacity
    sess.run(close_op)

sess = create_session()  # reuses create_session() from the first snippet
enqueue_thread = threading.Thread(target=enqueue, args=())
enqueue_thread.start()
dequeue_thread = threading.Thread(target=dequeue, args=())
dequeue_thread.start()
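If you run this as a standalone script, you may also want to wait for both threads to finish before exiting:

enqueue_thread.join()
dequeue_thread.join()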