Skip to content

Commit 49c8f0b

Browse files
lintian06copybara-github
authored andcommitted
For data loader, allow drop_remainer to be set.
PiperOrigin-RevId: 351491512
1 parent fa831d3 commit 49c8f0b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorflow_examples/lite/model_maker/core/data_util/dataloader.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def gen_dataset(self,
6363
is_training=False,
6464
shuffle=False,
6565
input_pipeline_context=None,
66-
preprocess=None):
66+
preprocess=None,
67+
drop_remainder=False):
6768
"""Generate a shared and batched tf.data.Dataset for training/evaluation.
6869
6970
Args:
@@ -76,6 +77,7 @@ def gen_dataset(self,
7677
among multiple workers when distribution strategy is used.
7778
preprocess: A function taking three arguments in order, feature, label and
7879
boolean is_training.
80+
drop_remainder: boolean, whether the finaly batch drops remainder.
7981
8082
Returns:
8183
A TF dataset ready to be consumed by Keras model.
@@ -100,7 +102,7 @@ def gen_dataset(self,
100102
ds = ds.shuffle(buffer_size=min(self._size, buffer_size))
101103
ds = ds.repeat()
102104

103-
ds = ds.batch(batch_size)
105+
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
104106
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
105107
# TODO(b/171449557): Consider converting ds to distributed ds here.
106108
return ds

0 commit comments

Comments
 (0)