diff --git a/optformer/embed_then_regress/configs.py b/optformer/embed_then_regress/configs.py index 4626d53..377af49 100644 --- a/optformer/embed_then_regress/configs.py +++ b/optformer/embed_then_regress/configs.py @@ -179,16 +179,11 @@ class DataConfig(abc.ABC): per_device_batch_size: int = 4 max_token_length: int = 256 - buffer_size: int = 10000 - def wrap_ds( self, ds: tf.data.Dataset, multi_gpu: bool = False ) -> tf.data.Dataset: """This should be used at the trainer level.""" ds = self._tokenize_ds(ds) - ds = ds.repeat() - ds = ds.shuffle(buffer_size=self.buffer_size) - ds = ds.batch(self.per_device_batch_size, drop_remainder=True) if multi_gpu: # Device count leading dimension, required by jax.pmap. ds = ds.batch(jax.local_device_count(), drop_remainder=True)