@@ -160,22 +160,44 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
160160 """Reads a dataset from tfds."""
161161 # No op if exist.
162162 tfds_builder .download_and_prepare ()
163-
164- read_config = tfds .ReadConfig (
165- interleave_cycle_length = cycle_length ,
166- interleave_block_length = block_length ,
167- input_context = input_context ,
168- shuffle_seed = seed )
169163 decoders = {}
170164 if tfds_skip_decoding_feature :
171165 for skip_feature in tfds_skip_decoding_feature .split (',' ):
172166 decoders [skip_feature .strip ()] = tfds .decode .SkipDecoding ()
173- dataset = tfds_builder .as_dataset (
174- split = tfds_split ,
175- shuffle_files = is_training ,
176- as_supervised = tfds_as_supervised ,
177- decoders = decoders ,
178- read_config = read_config )
167+ if tfds_builder .info .splits :
168+ num_shards = len (tfds_builder .info .splits [tfds_split ].file_instructions )
169+ else :
170+ # The tfds mock path often does not provide splits.
171+ num_shards = 1
172+ if input_context and num_shards < input_context .num_input_pipelines :
173+ # The number of files in the dataset split is smaller than the number of
174+ # input pipelines. We read the entire dataset first and then shard in the
175+ # host memory.
176+ read_config = tfds .ReadConfig (
177+ interleave_cycle_length = cycle_length ,
178+ interleave_block_length = block_length ,
179+ input_context = None ,
180+ shuffle_seed = seed )
181+ dataset = tfds_builder .as_dataset (
182+ split = tfds_split ,
183+ shuffle_files = is_training ,
184+ as_supervised = tfds_as_supervised ,
185+ decoders = decoders ,
186+ read_config = read_config )
187+ dataset = dataset .shard (input_context .num_input_pipelines ,
188+ input_context .input_pipeline_id )
189+ else :
190+ read_config = tfds .ReadConfig (
191+ interleave_cycle_length = cycle_length ,
192+ interleave_block_length = block_length ,
193+ input_context = input_context ,
194+ shuffle_seed = seed )
195+ dataset = tfds_builder .as_dataset (
196+ split = tfds_split ,
197+ shuffle_files = is_training ,
198+ as_supervised = tfds_as_supervised ,
199+ decoders = decoders ,
200+ read_config = read_config )
179201
180202 if is_training and not cache :
181203 dataset = dataset .repeat ()
0 commit comments