@@ -179,19 +179,22 @@ class DataConfig(abc.ABC):
179
179
per_device_batch_size : int = 4
180
180
max_token_length : int = 256
181
181
182
- buffer_size : int = 10000
183
-
184
182
def wrap_ds (
185
183
self , ds : tf .data .Dataset , multi_gpu : bool = False
186
184
) -> tf .data .Dataset :
187
185
"""This should be used at the trainer level."""
188
186
ds = self ._tokenize_ds (ds )
189
- ds = ds .repeat ()
190
- ds = ds .shuffle (buffer_size = self .buffer_size )
191
-
192
- ds = ds .batch (self .per_device_batch_size , drop_remainder = True )
187
+ ds = ds .batch (
188
+ self .per_device_batch_size ,
189
+ drop_remainder = True ,
190
+ num_parallel_calls = tf .data .AUTOTUNE ,
191
+ )
193
192
if multi_gpu : # Device count leading dimension, required by jax.pmap.
194
- ds = ds .batch (jax .local_device_count (), drop_remainder = True )
193
+ ds = ds .batch (
194
+ jax .local_device_count (),
195
+ drop_remainder = True ,
196
+ num_parallel_calls = tf .data .AUTOTUNE ,
197
+ )
195
198
ds = ds .prefetch (buffer_size = tf .data .AUTOTUNE )
196
199
return ds
197
200
@@ -217,10 +220,11 @@ def _tokenize_ds(self, ds: tf.data.Dataset) -> tf.data.Dataset:
217
220
transpose_x_only = lambda d : {
218
221
k : tf .transpose (v .to_tensor ()) if k == 'x' else v for k , v in d .items ()
219
222
}
220
- ds = ds .map (transpose_x_only )
223
+ ds = ds .map (transpose_x_only , num_parallel_calls = tf . data . AUTOTUNE )
221
224
ds = seqio .trim_and_pad_dataset (ds , feature_lengths )
222
225
ds = ds .map (
223
- lambda d : {k : tf .transpose (v ) if k == 'x' else v for k , v in d .items ()}
226
+ lambda d : {k : tf .transpose (v ) if k == 'x' else v for k , v in d .items ()},
227
+ num_parallel_calls = tf .data .AUTOTUNE ,
224
228
)
225
229
return ds
226
230
0 commit comments