Small fix in trainer data iteration
PiperOrigin-RevId: 693571495
xingyousong authored and copybara-github committed Nov 6, 2024
1 parent 840260a commit 1960063
Showing 1 changed file with 0 additions and 5 deletions.
optformer/embed_then_regress/configs.py — 0 additions & 5 deletions

@@ -179,16 +179,11 @@ class DataConfig(abc.ABC):
   per_device_batch_size: int = 4
   max_token_length: int = 256
 
-  buffer_size: int = 10000
-
   def wrap_ds(
       self, ds: tf.data.Dataset, multi_gpu: bool = False
   ) -> tf.data.Dataset:
     """This should be used at the trainer level."""
     ds = self._tokenize_ds(ds)
-    ds = ds.repeat()
-    ds = ds.shuffle(buffer_size=self.buffer_size)
-
     ds = ds.batch(self.per_device_batch_size, drop_remainder=True)
     if multi_gpu:  # Device count leading dimension, required by jax.pmap.
       ds = ds.batch(jax.local_device_count(), drop_remainder=True)
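Context for the change (not part of the commit): with ds.repeat() and ds.shuffle() removed from wrap_ds, the commit title suggests iteration is now driven at the trainer level. Below is a minimal sketch of what a caller-side loop might look like. Only wrap_ds, per_device_batch_size, and the multi_gpu flag come from the diff; the train function, the shuffle placement, and the explicit epoch loop are illustrative assumptions.

import jax
import tensorflow as tf

# Hypothetical trainer-side loop (not from the commit): with
# ds.repeat()/ds.shuffle() removed from wrap_ds(), shuffling and
# epoch boundaries become the caller's responsibility.
def train(config, raw_ds: tf.data.Dataset, num_epochs: int) -> None:
  # Shuffle before wrapping, if desired; reshuffle_each_iteration=True
  # (the tf.data default) re-shuffles on every pass below.
  raw_ds = raw_ds.shuffle(buffer_size=10_000)
  ds = config.wrap_ds(raw_ds, multi_gpu=jax.local_device_count() > 1)
  for epoch in range(num_epochs):  # explicit epochs replace ds.repeat()
    for batch in ds.as_numpy_iterator():
      ...  # run one training step on `batch`

Presumably this keeps wrap_ds a pure tokenize-and-batch step and gives the trainer control over epoching and shuffling; the matching trainer change is not shown in this commit.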
