File tree 1 file changed +4
-2
lines changed
tensorflow_examples/lite/model_maker/core/data_util
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -63,7 +63,8 @@ def gen_dataset(self,
63
63
is_training = False ,
64
64
shuffle = False ,
65
65
input_pipeline_context = None ,
66
- preprocess = None ):
66
+ preprocess = None ,
67
+ drop_remainder = False ):
67
68
"""Generate a shared and batched tf.data.Dataset for training/evaluation.
68
69
69
70
Args:
@@ -76,6 +77,7 @@ def gen_dataset(self,
76
77
among multiple workers when distribution strategy is used.
77
78
preprocess: A function taking three arguments in order, feature, label and
78
79
boolean is_training.
80
+ drop_remainder: boolean, whether the finaly batch drops remainder.
79
81
80
82
Returns:
81
83
A TF dataset ready to be consumed by Keras model.
@@ -100,7 +102,7 @@ def gen_dataset(self,
100
102
ds = ds .shuffle (buffer_size = min (self ._size , buffer_size ))
101
103
ds = ds .repeat ()
102
104
103
- ds = ds .batch (batch_size )
105
+ ds = ds .batch (batch_size , drop_remainder = drop_remainder )
104
106
ds = ds .prefetch (tf .data .experimental .AUTOTUNE )
105
107
# TODO(b/171449557): Consider converting ds to distributed ds here.
106
108
return ds
You can’t perform that action at this time.
0 commit comments