I'm familiar with the Dataset class in PyTorch: when calling torch.utils.data.DataLoader(), you can set num_workers=xxx to enable multi-process data loading (according to this https://discuss.pytorch.org/t/how-to-choose-the-value-of-the-num-workers-of-dataloader/53965/3, num_workers refers to the number of worker processes, and I do see multiple processes show up in "top", so correct me if I am wrong).
But when it comes to TensorFlow 2.x, the way the code is written seems to be a little different.
Here is my scenario:
I am doing an image classification task, so I need to do some data augmentation.
In PyTorch, I would put the flip and crop operations in the __getitem__() method of my Dataset, roughly as in the sketch below.
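To make that concrete, this is roughly the pattern I mean on the PyTorch side (a minimal sketch; the ImageListDataset name and the samples list are made up just for illustration):

    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from PIL import Image

    class ImageListDataset(Dataset):
        def __init__(self, samples):
            # samples: list of (image_path, label) tuples
            self.samples = samples
            # augmentation lives inside __getitem__ via this transform
            self.augment = transforms.Compose([
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path, label = self.samples[idx]
            image = Image.open(path).convert('RGB')
            return self.augment(image), label

    samples = [('dataset/training_dataset/dkakd.jpg', 24)]  # built from the file list in practice
    # num_workers=4 spawns 4 worker processes, each running __getitem__ in parallel
    loader = DataLoader(ImageListDataset(samples), batch_size=128,
                        shuffle=True, num_workers=4)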
According to this, Tensorflow 2.0 dataset and dataloader, I should write a function that does the augmentation and pass it to dataset.map(). The following is what I've achieved so far:
import numpy as np
import tensorflow as tf

def load_image(image_path, label):
    # image_path is just a string, something like 'dataset/training_dataset/xxxx.jpg'
    # label is an integer
    image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3, dct_method='INTEGER_ACCURATE')
    # 1. random crop to 224x224
    height, width, _ = image.shape
    crop_x = np.random.randint(0, width - 224)
    crop_y = np.random.randint(0, height - 224)
    image = tf.image.crop_to_bounding_box(image, crop_y, crop_x, 224, 224)
    # 2. random horizontal flip
    if np.random.rand() < 0.5:
        image = tf.image.flip_left_right(image)
    # scale pixels to [-1, 1]
    image = image / 127.5
    image -= 1
    label_one_hot = tf.one_hot(label, 50)
    return image, label_one_hot
def load_list(text_file):
    images, labels = [], []
    # read the text file; each line is "file_name label", something like
    # dkakd.jpg 24
    # adkjakd.jpg 12
    # and return lists of paths and labels, e.g.
    # 'dataset/training_dataset/dkakd.jpg', 24
    return images, labels
def get_train_dataset(text_file):
    images, labels = load_list(text_file)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(len(images))
    dataset = dataset.map(
        lambda x, y: tf.py_function(load_image, inp=[x, y], Tout=[tf.float32, tf.float32]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(128)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset
# I have two GPUs, so I did this according to the TF official tutorial
strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.ReductionToOneDevice())
train_dataset = get_train_dataset(text_file)
model = mobilenet_v2()  # I'm just using the official MobileNetV2 architecture
with strategy.scope():
    model.compile()
model.fit(train_dataset,
          epochs=15,
          workers=4,
          use_multiprocessing=True)
According to this, Parallelism isn't reducing the time in dataset map, num_parallel_calls only seems to enable multi-threading, which only helps when you use tf ops instead of custom Python functions. I also set workers and use_multiprocessing for fit(), but neither of them helps.
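For reference, here is roughly what I think a pure-tf-ops version of the map function would look like, using tf.image.random_crop and tf.image.random_flip_left_right instead of tf.py_function (just a sketch of what I understand that answer to suggest, not something I've confirmed solves the problem):

    def load_image_tf_ops(image_path, label):
        # everything here is a tf op, so dataset.map() can trace it into a graph
        # and run it across num_parallel_calls threads without the Python GIL
        image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3,
                                  dct_method='INTEGER_ACCURATE')
        image = tf.image.random_crop(image, size=[224, 224, 3])
        image = tf.image.random_flip_left_right(image)
        image = tf.cast(image, tf.float32) / 127.5 - 1.0
        label_one_hot = tf.one_hot(label, 50)
        return image, label_one_hot

    # used directly, without tf.py_function:
    # dataset = dataset.map(load_image_tf_ops,
    #                       num_parallel_calls=tf.data.experimental.AUTOTUNE)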
So what am I supposed to do to enable multi-processing and speed up training?
Thanks much in advance.
question from:
https://stackoverflow.com/questions/65948484/how-to-enable-multi-process-in-tensorflow-2-x