Skip to content

Commit aac3354

Browse files
saberkun and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 437358569
1 parent b6fcc07 commit aac3354

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

official/core/input_reader.py

+34-12
Original file line number | Diff line number | Diff line change
@@ -160,22 +160,44 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
160160
"""Reads a dataset from tfds."""
161161
# No op if exist.
162162
tfds_builder.download_and_prepare()
163-
164-
read_config = tfds.ReadConfig(
165-
interleave_cycle_length=cycle_length,
166-
interleave_block_length=block_length,
167-
input_context=input_context,
168-
shuffle_seed=seed)
169163
decoders = {}
170164
if tfds_skip_decoding_feature:
171165
for skip_feature in tfds_skip_decoding_feature.split(','):
172166
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
173-
dataset = tfds_builder.as_dataset(
174-
split=tfds_split,
175-
shuffle_files=is_training,
176-
as_supervised=tfds_as_supervised,
177-
decoders=decoders,
178-
read_config=read_config)
167+
if tfds_builder.info.splits:
168+
num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
169+
else:
170+
# The tfds mock path often does not provide splits.
171+
num_shards = 1
172+
if input_context and num_shards < input_context.num_input_pipelines:
173+
# The number of files in the dataset split is smaller than the number of
174+
# input pipelines. We read the entire dataset first and then shard in the
175+
# host memory.
176+
read_config = tfds.ReadConfig(
177+
interleave_cycle_length=cycle_length,
178+
interleave_block_length=block_length,
179+
input_context=None,
180+
shuffle_seed=seed)
181+
dataset = tfds_builder.as_dataset(
182+
split=tfds_split,
183+
shuffle_files=is_training,
184+
as_supervised=tfds_as_supervised,
185+
decoders=decoders,
186+
read_config=read_config)
187+
dataset = dataset.shard(input_context.num_input_pipelines,
188+
input_context.input_pipeline_id)
189+
else:
190+
read_config = tfds.ReadConfig(
191+
interleave_cycle_length=cycle_length,
192+
interleave_block_length=block_length,
193+
input_context=input_context,
194+
shuffle_seed=seed)
195+
dataset = tfds_builder.as_dataset(
196+
split=tfds_split,
197+
shuffle_files=is_training,
198+
as_supervised=tfds_as_supervised,
199+
decoders=decoders,
200+
read_config=read_config)
179201

180202
if is_training and not cache:
181203
dataset = dataset.repeat()

0 commit comments

Comments (0)