File tree Expand file tree Collapse file tree
tensorflow_examples/lite/model_maker/core/data_util Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -63,7 +63,8 @@ def gen_dataset(self,
6363 is_training = False ,
6464 shuffle = False ,
6565 input_pipeline_context = None ,
66- preprocess = None ):
66+ preprocess = None ,
67+ drop_remainder = False ):
6768 """Generate a shared and batched tf.data.Dataset for training/evaluation.
6869
6970 Args:
@@ -76,6 +77,7 @@ def gen_dataset(self,
7677 among multiple workers when distribution strategy is used.
7778 preprocess: A function taking three arguments in order, feature, label and
7879 boolean is_training.
80+ drop_remainder: boolean, whether the finaly batch drops remainder.
7981
8082 Returns:
8183 A TF dataset ready to be consumed by Keras model.
@@ -100,7 +102,7 @@ def gen_dataset(self,
100102 ds = ds .shuffle (buffer_size = min (self ._size , buffer_size ))
101103 ds = ds .repeat ()
102104
103- ds = ds .batch (batch_size )
105+ ds = ds .batch (batch_size , drop_remainder = drop_remainder )
104106 ds = ds .prefetch (tf .data .experimental .AUTOTUNE )
105107 # TODO(b/171449557): Consider converting ds to distributed ds here.
106108 return ds
You can’t perform that action at this time.
0 commit comments