@@ -160,22 +160,44 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
   """Reads a dataset from tfds."""
   # No op if exist.
   tfds_builder.download_and_prepare()
-
-  read_config = tfds.ReadConfig(
-      interleave_cycle_length=cycle_length,
-      interleave_block_length=block_length,
-      input_context=input_context,
-      shuffle_seed=seed)
   decoders = {}
   if tfds_skip_decoding_feature:
     for skip_feature in tfds_skip_decoding_feature.split(','):
       decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
-  dataset = tfds_builder.as_dataset(
-      split=tfds_split,
-      shuffle_files=is_training,
-      as_supervised=tfds_as_supervised,
-      decoders=decoders,
-      read_config=read_config)
+  if tfds_builder.info.splits:
+    num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
+  else:
+    # The tfds mock path often does not provide splits.
+    num_shards = 1
+  if input_context and num_shards < input_context.num_input_pipelines:
+    # The number of files in the dataset split is smaller than the number
+    # of input pipelines. Read the entire dataset first, then shard it in
+    # host memory.
+    read_config = tfds.ReadConfig(
+        interleave_cycle_length=cycle_length,
+        interleave_block_length=block_length,
+        input_context=None,
+        shuffle_seed=seed)
+    dataset = tfds_builder.as_dataset(
+        split=tfds_split,
+        shuffle_files=is_training,
+        as_supervised=tfds_as_supervised,
+        decoders=decoders,
+        read_config=read_config)
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)
+  else:
+    read_config = tfds.ReadConfig(
+        interleave_cycle_length=cycle_length,
+        interleave_block_length=block_length,
+        input_context=input_context,
+        shuffle_seed=seed)
+    dataset = tfds_builder.as_dataset(
+        split=tfds_split,
+        shuffle_files=is_training,
+        as_supervised=tfds_as_supervised,
+        decoders=decoders,
+        read_config=read_config)
 
   if is_training and not cache:
     dataset = dataset.repeat()
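
For context on when the new branch fires, here is a minimal, runnable sketch of the condition (the shard count and pipeline count below are hypothetical values for illustration, not taken from the patch):

```python
import tensorflow as tf

# Hypothetical setup: four input pipelines reading a split that has only
# three files, so tfds cannot give every pipeline its own subset of files.
input_context = tf.distribute.InputContext(
    num_input_pipelines=4, input_pipeline_id=0)
num_shards = 3  # stands in for len(...file_instructions) in the patch

# Mirrors the patch's branch condition.
if input_context and num_shards < input_context.num_input_pipelines:
  print('Fall back: read everything, then shard elements in host memory.')
else:
  print('Let tfds shard by file via ReadConfig(input_context=...).')
```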
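And a self-contained sketch of the fallback itself, with tf.data.Dataset.range standing in for the fully-read tfds dataset (the ten elements and two pipelines are illustrative assumptions):

```python
import tensorflow as tf

num_input_pipelines = 2  # hypothetical pipeline count
full_dataset = tf.data.Dataset.range(10)  # stand-in for the tfds dataset

# Dataset.shard(n, i) keeps every n-th element starting at index i, so each
# pipeline ends up with a disjoint slice of the same fully-read dataset.
for pipeline_id in range(num_input_pipelines):
  shard = full_dataset.shard(num_input_pipelines, pipeline_id)
  print(pipeline_id, list(shard.as_numpy_iterator()))
# pipeline 0 -> [0, 2, 4, 6, 8]; pipeline 1 -> [1, 3, 5, 7, 9]
```

The trade-off is that every host now reads the whole split, which costs extra I/O, but each pipeline still receives a disjoint, evenly sized share; that is why the patch only takes this path when there are fewer files than input pipelines and file-level sharding cannot work.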