keras - How to enable multi-process in Tensorflow 2.x?

I am familiar with the Dataset class in PyTorch: when calling torch.utils.data.DataLoader(), you can set num_workers=xxx to load the data in multiple processes (according to https://discuss.pytorch.org/t/how-to-choose-the-value-of-the-num-workers-of-dataloader/53965/3, num_workers refers to the number of worker processes, and I do see multiple processes show up in the top command; correct me if I am wrong).
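For reference, this is roughly the PyTorch pattern I mean (untested sketch; the dataset class and dummy shapes are made up for illustration):

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    # Illustrative dataset; __getitem__ runs inside each worker process.
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        # load + augment one sample here (crop, flip, ...)
        return torch.zeros(3, 224, 224), idx % 50

# num_workers=4 starts 4 loader processes, each calling __getitem__
loader = DataLoader(ToyDataset(), batch_size=128, num_workers=4)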

But when it comes to TensorFlow 2.x, the way to write this seems to be a little different. Here is my scenario: I am doing an image-classification task, so I need some data augmentation. In PyTorch I would put the flip and crop operations in __getitem__(). According to Tensorflow 2.0 dataset and dataloader, I should instead write a function that does the augmentation and pass it to dataset.map(). Here is what I have achieved so far:

import numpy as np
import tensorflow as tf

def load_image(image_path, label):
    # image_path is just a string, something like 'dataset/training_dataset/xxxx.jpg'
    # label is an integer
    image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3, dct_method='INTEGER_ACCURATE')

    # 1. random crop (assumes the source image is larger than 224x224)
    height, width, _ = image.shape
    crop_x = np.random.randint(0, width - 224)
    crop_y = np.random.randint(0, height - 224)
    image = tf.image.crop_to_bounding_box(image, crop_y, crop_x, 224, 224)

    # 2. random horizontal flip
    if np.random.rand() < 0.5:
        image = tf.image.flip_left_right(image)

    # scale uint8 pixels to [-1, 1]
    image = tf.cast(image, tf.float32) / 127.5
    image -= 1

    label_one_hot = tf.one_hot(label, 50)

    return image, label_one_hot
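For what it's worth, I know the same augmentation could probably be written with tf.ops only, so dataset.map() could trace it into the graph instead of going through tf.py_function. An untested sketch of that idea (load_image_tf is just a name I made up):

def load_image_tf(image_path, label):
    # Pure-tf.ops variant: traceable by dataset.map(), so
    # num_parallel_calls can actually run copies in parallel.
    image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3)
    image = tf.image.random_crop(image, size=[224, 224, 3])  # random 224x224 crop
    image = tf.image.random_flip_left_right(image)           # flip with prob 0.5
    image = tf.cast(image, tf.float32) / 127.5 - 1.0         # scale to [-1, 1]
    return image, tf.one_hot(label, 50)

With that version, dataset.map(load_image_tf, num_parallel_calls=tf.data.experimental.AUTOTUNE) would need no py_function wrapper at all.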

def load_list(text_file):
    # Read a text file of "file_name label" lines, e.g.
    #   dkakd.jpg  24
    #   adkjakd.jpg  12
    # and return full paths plus integer labels, e.g.
    #   'dataset/training_dataset/dkakd.jpg', 24
    images, labels = [], []
    with open(text_file) as f:
        for line in f:
            name, label = line.split()
            images.append('dataset/training_dataset/' + name)
            labels.append(int(label))
    return images, labels

def get_train_dataset(text_file):
    images, labels = load_list(text_file)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(len(images))
    dataset = dataset.map(
        lambda x, y: tf.py_function(load_image, inp=[x, y], Tout=[tf.float32, tf.float32]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(128)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset
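(For completeness, I have also seen that TF 2.3+ ships the tf.data service, which moves the input pipeline into separate server processes. A minimal single-machine sketch, assuming TF 2.4+ and adapted from the TF docs, not something I have tried here:

dispatcher = tf.data.experimental.service.DispatchServer()
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher.target.split("://")[1]))

dataset = get_train_dataset(text_file)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=dispatcher.target))

In a real setup the dispatcher and workers would run in their own processes or on other hosts.)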


# I have two GPUs, so I use MirroredStrategy per the TF official tutorial
strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.ReductionToOneDevice())
train_dataset = get_train_dataset(text_file)
with strategy.scope():
    # model creation and compile both belong inside the strategy scope
    model = mobilenet_v2()  # I'm just using the official structure of MobileNetV2
    model.compile()

model.fit(train_dataset,
          epochs=15,
          workers=4,
          use_multiprocessing=True)

According to Parallelism isn't reducing the time in dataset map, num_parallel_calls seems to enable only multi-threading, which helps only when you use tf.ops rather than custom Python functions. I also set workers and use_multiprocessing for fit(), but neither has any effect. So what am I supposed to do to enable multi-processing and speed up training? Thanks much in advance.
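One thing I noticed in the Keras docs: workers and use_multiprocessing apply only to Python-generator or keras.utils.Sequence inputs, not to a tf.data.Dataset, which might explain why they had no effect for me. If that's right, a Sequence-based pipeline would look roughly like this (untested sketch; the class name and dummy shapes are mine):

import math
import numpy as np
import tensorflow as tf

class AugmentingSequence(tf.keras.utils.Sequence):
    # Illustrative Sequence; with use_multiprocessing=True, fit()
    # pulls batches from several worker processes.
    def __init__(self, images, labels, batch_size=128):
        self.images, self.labels, self.batch_size = images, labels, batch_size

    def __len__(self):
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        batch_paths = self.images[idx * self.batch_size:(idx + 1) * self.batch_size]
        # load + augment each path here, as in load_image() above
        x = np.zeros((len(batch_paths), 224, 224, 3), np.float32)
        y = np.zeros((len(batch_paths), 50), np.float32)
        return x, y

# model.fit(AugmentingSequence(images, labels),
#           epochs=15, workers=4, use_multiprocessing=True)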

question from:https://stackoverflow.com/questions/65948484/how-to-enable-multi-process-in-tensorflow-2-x


1 Reply

Waiting for answers
