Skip to content

Commit 5d528d9

Browse files
xingyousong authored and copybara-github committed
Small fix in trainer data iteration
PiperOrigin-RevId: 693753499
1 parent 69d1ed1 commit 5d528d9

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

optformer/embed_then_regress/configs.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -179,19 +179,22 @@ class DataConfig(abc.ABC):
179179
per_device_batch_size: int = 4
180180
max_token_length: int = 256
181181

182-
buffer_size: int = 10000
183-
184182
def wrap_ds(
185183
self, ds: tf.data.Dataset, multi_gpu: bool = False
186184
) -> tf.data.Dataset:
187185
"""This should be used at the trainer level."""
188186
ds = self._tokenize_ds(ds)
189-
ds = ds.repeat()
190-
ds = ds.shuffle(buffer_size=self.buffer_size)
191-
192-
ds = ds.batch(self.per_device_batch_size, drop_remainder=True)
187+
ds = ds.batch(
188+
self.per_device_batch_size,
189+
drop_remainder=True,
190+
num_parallel_calls=tf.data.AUTOTUNE,
191+
)
193192
if multi_gpu: # Device count leading dimension, required by jax.pmap.
194-
ds = ds.batch(jax.local_device_count(), drop_remainder=True)
193+
ds = ds.batch(
194+
jax.local_device_count(),
195+
drop_remainder=True,
196+
num_parallel_calls=tf.data.AUTOTUNE,
197+
)
195198
ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
196199
return ds
197200

@@ -217,10 +220,11 @@ def _tokenize_ds(self, ds: tf.data.Dataset) -> tf.data.Dataset:
217220
transpose_x_only = lambda d: {
218221
k: tf.transpose(v.to_tensor()) if k == 'x' else v for k, v in d.items()
219222
}
220-
ds = ds.map(transpose_x_only)
223+
ds = ds.map(transpose_x_only, num_parallel_calls=tf.data.AUTOTUNE)
221224
ds = seqio.trim_and_pad_dataset(ds, feature_lengths)
222225
ds = ds.map(
223-
lambda d: {k: tf.transpose(v) if k == 'x' else v for k, v in d.items()}
226+
lambda d: {k: tf.transpose(v) if k == 'x' else v for k, v in d.items()},
227+
num_parallel_calls=tf.data.AUTOTUNE,
224228
)
225229
return ds
226230

0 commit comments

Comments
 (0)