While training my model I ran into the issue described in the post "Tensorflow - Keras: Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset". My question now is: does the solution mentioned by @Graham501617 work with generators as well? Here is some dummy code for what I use so far:
```python
from tensorflow.keras.utils import Sequence

class BatchGenerator(Sequence):
    def __init__(self, some_args):
        ...

    def __len__(self):
        # number of batches per epoch
        num_batches_in_sequence = ...
        return num_batches_in_sequence

    def __getitem__(self, _):
        # returns one batch of data and the corresponding labels
        data, labels = get_one_batch(self.some_args)
        return data, labels
```
In the main script I do something like:
```python
import tensorflow as tf

train_generator = BatchGenerator(some_args)
valid_generator = BatchGenerator(some_args)

cross_device_ops = tf.distribute.HierarchicalCopyAllReduce(num_packs=2)
strategy = tf.distribute.MirroredStrategy(cross_device_ops=cross_device_ops)

with strategy.scope():
    model = some_model
    model.compile(some_args)
    history = model.fit(
        x=train_generator,
        validation_data=valid_generator,
        ...
    )
```
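As far as I understand it, the workaround from the linked post boils down to setting the shard policy on a tf.data.Dataset. This is my own rough sketch of that idea (not tested with my setup):

```python
options = tf.data.Options()
# either turn auto-sharding off ...
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
# ... or shard by data instead of by file:
# options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

# 'dataset' would have to be a tf.data.Dataset, which is exactly what I
# don't have, since I feed a keras.utils.Sequence to model.fit()
dataset = dataset.with_options(options)
```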
I would probably have to modify the `__getitem__` method somehow, wouldn't I?
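If I had to guess, I would wrap the Sequence in a tf.data.Dataset so the options above can be applied, reusing `__getitem__` as-is. Something like the following is what I have in mind (untested; `sequence_to_dataset` and the TensorSpec shapes/dtypes are placeholders I made up for illustration):

```python
def sequence_to_dataset(seq, options):
    # Hypothetical helper: iterate over the Sequence batch by batch via __getitem__
    dataset = tf.data.Dataset.from_generator(
        lambda: (seq[i] for i in range(len(seq))),
        output_signature=(
            tf.TensorSpec(shape=(None, None), dtype=tf.float32),  # data (placeholder shape)
            tf.TensorSpec(shape=(None,), dtype=tf.float32),       # labels (placeholder shape)
        ),
    )
    return dataset.with_options(options)

train_dataset = sequence_to_dataset(train_generator, options)
valid_dataset = sequence_to_dataset(valid_generator, options)

history = model.fit(x=train_dataset, validation_data=valid_dataset)
```

Is something along these lines necessary, or does the generator work with the options directly?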
I appreciate your support!