Skip to content

Commit d975425

Browse files
marcenacp authored and The TensorFlow Datasets Authors committed
Read the length of the datasource from the FileInstructions to limit I/O.
PiperOrigin-RevId: 737687954
1 parent 27547b2 commit d975425

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

tensorflow_datasets/core/data_sources/array_record.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class ArrayRecordDataSource(base.BaseDataSource):
5656
length: int = dataclasses.field(init=False)
5757

5858
def __post_init__(self):
59-
file_instructions = base.file_instructions(self.dataset_info, self.split)
59+
file_instructions = self.split_info.file_instructions
6060
self.data_source = array_record_data_source.ArrayRecordDataSource(
6161
file_instructions
6262
)

tensorflow_datasets/core/data_sources/base.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T:
4545
"""Returns the value for the given `keys`."""
4646

4747

48-
def file_instructions(
49-
dataset_info: dataset_info_lib.DatasetInfo,
50-
split: splits_lib.Split | None = None,
51-
) -> list[shard_utils.FileInstruction]:
52-
"""Retrieves the file instructions from the DatasetInfo."""
53-
split_infos = dataset_info.splits.values()
54-
split_dict = splits_lib.SplitDict(split_infos=split_infos)
55-
return split_dict[split].file_instructions
56-
57-
5848
@dataclasses.dataclass
5949
class BaseDataSource(MappingView, Sequence):
6050
"""Base DataSource to override all dunder methods with the deserialization.
@@ -94,6 +84,13 @@ def _deserialize(self, record: Any) -> Any:
9484
return features.deserialize_example_np(record, decoders=self.decoders) # pylint: disable=attribute-error
9585
raise ValueError('No features set, cannot decode example!')
9686

87+
@property
88+
def split_info(self) -> splits_lib.SplitInfo | splits_lib.SubSplitInfo:
89+
"""Returns the SplitInfo for the split."""
90+
split_infos = self.dataset_info.splits.values()
91+
splits_dict = splits_lib.SplitDict(split_infos=split_infos)
92+
return splits_dict[self.split] # will raise an error if split is not found
93+
9794
def __getitem__(self, key: SupportsIndex) -> Any:
9895
record = self.data_source[key.__index__()]
9996
return self._deserialize(record)
@@ -133,7 +130,7 @@ def __repr__(self) -> str:
133130
)
134131

135132
def __len__(self) -> int:
136-
return self.data_source.__len__()
133+
return sum(fi.take for fi in self.split_info.file_instructions)
137134

138135
def __iter__(self):
139136
for i in range(self.__len__()):

tensorflow_datasets/core/data_sources/base_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ def test_read_write(
9494
for i, element in enumerate(data_source):
9595
assert element == {'id': i}
9696

97+
# Also works on sliced splits.
98+
data_source = builder.as_data_source(split='train[0:2]')
99+
assert len(data_source) == 2
100+
data_source = builder.as_data_source(split='train[:50%]')
101+
assert len(data_source) == 2
102+
97103

98104
_FILE_INSTRUCTIONS = [
99105
shard_utils.FileInstruction(

tensorflow_datasets/core/data_sources/parquet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class ParquetDataSource(base.BaseDataSource):
5757
"""ParquetDataSource to read from a ParquetDataset."""
5858

5959
def __post_init__(self):
60-
file_instructions = base.file_instructions(self.dataset_info, self.split)
60+
file_instructions = self.split_info.file_instructions
6161
filenames = [
6262
file_instruction.filename for file_instruction in file_instructions
6363
]

0 commit comments

Comments (0)