
Commit 2e87765

tomvdw authored and The TensorFlow Datasets Authors committed
Use the max size of serialized examples to find a safe number of shards

If we know the max size of serialized examples, we can account for the worst case in which one shard receives only examples of that max size. This should prevent users from running into problems with shards that are too big.

PiperOrigin-RevId: 726377778
1 parent 281ce2d commit 2e87765
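
To make the worst-case reasoning concrete, here is a back-of-the-envelope sketch. The numbers are illustrative only (the ~1 GiB shard-size cap is an assumption, not necessarily the exact ShardConfig default), but they show why sizing shards by the average example size can undercount:

num_examples = 10**9
avg_size = 1_000          # average serialized example size (bytes)
max_size = 4_000          # largest serialized example size (bytes)
max_shard_size = 1 << 30  # assumed shard-size cap (bytes)

# Ceiling division: sizing by the average assumes examples mix evenly across
# shards; sizing by the max covers a shard that collects only large examples.
shards_by_average = -(-num_examples * avg_size // max_shard_size)  # 932
shards_worst_case = -(-num_examples * max_size // max_shard_size)  # 3726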

File tree

6 files changed: +157 −24 lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py (+1)

@@ -406,6 +406,7 @@ def _compute_shard_specs(
   # HF split size is good enough for estimating the number of shards.
   num_shards = shard_utils.ShardConfig.calculate_number_shards(
       total_size=hf_split_info.num_bytes,
+      max_size=None,
       num_examples=hf_split_info.num_examples,
       uses_precise_sharding=False,
   )
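
Note: the Hugging Face split metadata used here exposes only total bytes and example counts, so no per-example maximum appears to be available; passing max_size=None keeps the previous estimate based on total size alone.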

tensorflow_datasets/core/shuffle.py (+9 −3)

@@ -244,9 +244,10 @@ def __init__(
     self._total_bytes = 0
     # To keep data in memory until enough data has been gathered.
     self._in_memory = True
-    self._mem_buffer = []
+    self._mem_buffer: list[type_utils.KeySerializedExample] = []
     self._seen_keys: set[int] = set()
     self._num_examples = 0
+    self._max_size = 0

   @property
   def size(self) -> int:
@@ -263,6 +264,10 @@ def bucket_lengths(self) -> Sequence[int]:
   def num_examples(self) -> int:
     return self._num_examples

+  @property
+  def max_size(self) -> int:
+    return self._max_size
+
   def _add_to_bucket(self, hkey: int, data: bytes) -> None:
     bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
     self._buckets[bucket_number].add(hkey, data)
@@ -272,10 +277,10 @@ def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
     if self._total_bytes > MAX_MEM_BUFFER_SIZE:
       for hkey, data in self._mem_buffer:
         self._add_to_bucket(hkey, data)
-      self._mem_buffer = None
+      self._mem_buffer = []
       self._in_memory = False

-  def add(self, key: type_utils.Key, data: bytes) -> bool:
+  def add(self, key: type_utils.Key, data: bytes) -> None:
     """Add (key, data) to shuffler."""
     if self._read_only:
       raise AssertionError('add() cannot be called after __iter__.')
@@ -299,6 +304,7 @@ def add(self, key: type_utils.Key, data: bytes) -> bool:
     else:
       self._add_to_bucket(hkey, data)
     self._num_examples += 1  # pytype: disable=bad-return-type
+    self._max_size = max(self._max_size, len(data))

   def __iter__(self) -> Iterator[type_utils.KeySerializedExample]:
     self._read_only = True
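
The bookkeeping added to the Shuffler is a running maximum over serialized-example lengths. A minimal standalone sketch of the same pattern (illustrative names, not the tfds internals):

class SizeTracker:
  """Tracks count, total bytes, and max bytes of serialized examples."""

  def __init__(self):
    self.num_examples = 0
    self.total_bytes = 0
    self.max_size = 0

  def add(self, data: bytes) -> None:
    self.num_examples += 1
    self.total_bytes += len(data)
    self.max_size = max(self.max_size, len(data))

tracker = SizeTracker()
for example in (b'ab', b'abcdef', b'abc'):
  tracker.add(example)
print(tracker.max_size)  # 6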

tensorflow_datasets/core/utils/shard_utils.py (+24 −5)

@@ -57,27 +57,39 @@ class ShardConfig:
   def calculate_number_shards(
       cls,
       total_size: int,
+      max_size: int | None,
       num_examples: int,
       uses_precise_sharding: bool = True,
   ) -> int:
     """Returns number of shards for num_examples of total_size in bytes.

     Args:
-      total_size: the size of the data (serialized, not couting any overhead).
+      total_size: the size of the data (serialized, not counting any overhead).
+      max_size: the maximum size of a single example (serialized, not counting
+        any overhead).
       num_examples: the number of records in the data.
       uses_precise_sharding: whether a mechanism is used to exactly control how
         many examples go in each shard.
     """
-    total_size += num_examples * cls.overhead
-    max_shards_number = total_size // cls.min_shard_size
+    total_overhead = num_examples * cls.overhead
+    total_size_with_overhead = total_size + total_overhead
     if uses_precise_sharding:
       max_shard_size = cls.max_shard_size
     else:
       # When the pipeline does not control exactly how many rows go into each
       # shard (called 'precise sharding' here), we use a smaller max shard size
       # so that the pipeline doesn't fail if a shard gets some more examples.
       max_shard_size = 0.9 * cls.max_shard_size
-    min_shards_number = total_size // max_shard_size
+    max_shard_size = max(1, max_shard_size)
+
+    if max_size is None:
+      min_shards_number = max(1, total_size_with_overhead // max_shard_size)
+      max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size)
+    else:
+      pessimistic_total_size = num_examples * (max_size + cls.overhead)
+      min_shards_number = max(1, pessimistic_total_size // max_shard_size)
+      max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size)
+
     if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
       return 1024
     elif min_shards_number > 1024:
@@ -96,15 +108,22 @@ def calculate_number_shards(
   def get_number_shards(
       self,
       total_size: int,
+      max_size: int | None,
       num_examples: int,
       uses_precise_sharding: bool = True,
   ) -> int:
     if self.num_shards:
       return self.num_shards
     return self.calculate_number_shards(
-        total_size, num_examples, uses_precise_sharding
+        total_size=total_size,
+        max_size=max_size,
+        num_examples=num_examples,
+        uses_precise_sharding=uses_precise_sharding,
     )

+  def replace(self, **kwargs: Any) -> ShardConfig:
+    return dataclasses.replace(self, **kwargs)
+

 def get_shard_boundaries(
     num_examples:
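
As a usage sketch, the new parameter threads through like this; the inputs mirror the '100GiB, uneven example sizes' test case added below, which expects 4096 shards with the default ShardConfig:

from tensorflow_datasets.core.utils import shard_utils

num_shards = shard_utils.ShardConfig().get_number_shards(
    total_size=10**9 * 1000,  # ~1 TB of serialized data in total
    max_size=4 * 1000,        # worst single example is 4x the average
    num_examples=10**9,
    uses_precise_sharding=True,
)
print(num_shards)  # 4096, per the test below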

tensorflow_datasets/core/utils/shard_utils_test.py (+91 −9)

@@ -22,23 +22,102 @@
 class ShardConfigTest(parameterized.TestCase):

   @parameterized.named_parameters(
-      ('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024),
-      ('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64),
-      ('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512),
-      ('xxl, 10 TiB', 10 << 40, 10**9, True, 11264),
-      ('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808),
-      ('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1),
-      ('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4),
+      dict(
+          testcase_name='imagenet train, 137 GiB',
+          total_size=137 << 30,
+          num_examples=1281167,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=1024,
+      ),
+      dict(
+          testcase_name='imagenet evaluation, 6.3 GiB',
+          total_size=6300 * (1 << 20),
+          num_examples=50000,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=64,
+      ),
+      dict(
+          testcase_name='very large, but few examples, 52 GiB',
+          total_size=52 << 30,
+          num_examples=512,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=512,
+      ),
+      dict(
+          testcase_name='xxl, 10 TiB',
+          total_size=10 << 40,
+          num_examples=10**9,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=11264,
+      ),
+      dict(
+          testcase_name='xxl, 10 PiB, 100B examples',
+          total_size=10 << 50,
+          num_examples=10**11,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=10487808,
+      ),
+      dict(
+          testcase_name='xs, 100 MiB, 100K records',
+          total_size=10 << 20,
+          num_examples=100 * 10**3,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=1,
+      ),
+      dict(
+          testcase_name='m, 499 MiB, 200K examples',
+          total_size=400 << 20,
+          num_examples=200 * 10**3,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=4,
+      ),
+      dict(
+          testcase_name='100GiB, even example sizes',
+          num_examples=1e9,  # 1B examples
+          total_size=1e9 * 1000,  # On average 1000 bytes per example
+          max_size=1000,  # Max example size equals the 1000-byte average
+          uses_precise_sharding=True,
+          expected_num_shards=1024,
+      ),
+      dict(
+          testcase_name='100GiB, uneven example sizes',
+          num_examples=1e9,  # 1B examples
+          total_size=1e9 * 1000,  # On average 1000 bytes per example
+          max_size=4 * 1000,  # Max example size is 4000 bytes
+          uses_precise_sharding=True,
+          expected_num_shards=4096,
+      ),
+      dict(
+          testcase_name='100GiB, very uneven example sizes',
+          num_examples=1e9,  # 1B examples
+          total_size=1e9 * 1000,  # On average 1000 bytes per example
+          max_size=16 * 1000,  # Max example size is 16x the 1000-byte average
+          uses_precise_sharding=True,
+          expected_num_shards=15360,
+      ),
   )
   def test_get_number_shards_default_config(
-      self, total_size, num_examples, uses_precise_sharding, expected_num_shards
+      self,
+      total_size: int,
+      num_examples: int,
+      uses_precise_sharding: bool,
+      max_size: int | None,
+      expected_num_shards: int,
   ):
     shard_config = shard_utils.ShardConfig()
     self.assertEqual(
         expected_num_shards,
         shard_config.get_number_shards(
             total_size=total_size,
             num_examples=num_examples,
+            max_size=max_size,
             uses_precise_sharding=uses_precise_sharding,
         ),
     )
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
     self.assertEqual(
         42,
         shard_config.get_number_shards(
-            total_size=100, num_examples=1, uses_precise_sharding=True
+            total_size=100,
+            max_size=100,
+            num_examples=1,
+            uses_precise_sharding=True,
         ),
     )

tensorflow_datasets/core/writer.py (+30 −7)

@@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike:
 def _get_shard_specs(
     num_examples: int,
     total_size: int,
+    max_size: int | None,
     bucket_lengths: Sequence[int],
     filename_template: naming.ShardedFileTemplate,
     shard_config: shard_utils.ShardConfig,
@@ -125,11 +126,14 @@ def _get_shard_specs(
   Args:
     num_examples: int, number of examples in split.
     total_size: int (bytes), sum of example sizes.
+    max_size: int (bytes), maximum size of a single example.
     bucket_lengths: list of ints, number of examples in each bucket.
     filename_template: template to format sharded filenames.
     shard_config: the configuration for creating shards.
   """
-  num_shards = shard_config.get_number_shards(total_size, num_examples)
+  num_shards = shard_config.get_number_shards(
+      total_size=total_size, max_size=max_size, num_examples=num_examples
+  )
   shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards)
   shard_specs = []
   bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
@@ -372,6 +376,7 @@ def finalize(self) -> tuple[list[int], int]:
     shard_specs = _get_shard_specs(
         num_examples=self._shuffler.num_examples,
         total_size=self._shuffler.size,
+        max_size=self._shuffler.max_size,
         bucket_lengths=self._shuffler.bucket_lengths,
         filename_template=self._filename_template,
         shard_config=self._shard_config,
@@ -589,10 +594,13 @@ def _write_final_shard(
         id=shard_id, num_examples=len(example_by_key), size=shard_size
     )

-  def _number_of_shards(self, num_examples: int, total_size: int) -> int:
+  def _number_of_shards(
+      self, num_examples: int, total_size: int, max_size: int
+  ) -> int:
     """Returns the number of shards."""
     num_shards = self._shard_config.get_number_shards(
         total_size=total_size,
+        max_size=max_size,
         num_examples=num_examples,
         uses_precise_sharding=False,
     )
@@ -658,16 +666,24 @@ def write_from_pcollection(self, examples_pcollection):
         | "CountExamples" >> beam.combiners.Count.Globally()
         | "CheckValidNumExamples" >> beam.Map(self._check_num_examples)
     )
+    serialized_example_sizes = (
+        serialized_examples | beam.Values() | beam.Map(len)
+    )
     total_size = beam.pvalue.AsSingleton(
-        serialized_examples
-        | beam.Values()
-        | beam.Map(len)
-        | "TotalSize" >> beam.CombineGlobally(sum)
+        serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum)
+    )
+
+    max_size = beam.pvalue.AsSingleton(
+        serialized_example_sizes
+        | "TopExampleSize" >> beam.combiners.Top.Largest(1)
+        | "MaxExampleSize" >> beam.CombineGlobally(_get_max_size)
     )
     ideal_num_shards = beam.pvalue.AsSingleton(
         num_examples
         | "NumberOfShards"
-        >> beam.Map(self._number_of_shards, total_size=total_size)
+        >> beam.Map(
+            self._number_of_shards, total_size=total_size, max_size=max_size
+        )
     )

     examples_per_shard = (
@@ -826,3 +842,10 @@ def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
     )

     return shard_lengths, total_size_bytes
+
+
+def _get_max_size(sizes: Iterable[int]) -> int | None:
+  sizes = list(sizes)
+  if not sizes:
+    return None
+  return max(sizes)
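
The same aggregation pattern can be sketched as standalone Beam, simplified in two ways: it starts from a PCollection of sizes rather than (key, serialized_example) pairs, and it unwraps the Top.Largest result with a FlatMap instead of the _get_max_size helper above:

import apache_beam as beam

with beam.Pipeline() as pipeline:
  sizes = pipeline | 'Sizes' >> beam.Create([120, 4096, 830])

  total_size = sizes | 'TotalSize' >> beam.CombineGlobally(sum)
  max_size = (
      sizes
      | 'TopExampleSize' >> beam.combiners.Top.Largest(1)
      | 'UnwrapMax' >> beam.FlatMap(lambda tops: tops)  # [4096] -> 4096
  )

  total_size | 'PrintTotal' >> beam.Map(print)  # 5046
  max_size | 'PrintMax' >> beam.Map(print)      # 4096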

tensorflow_datasets/core/writer_test.py (+2)

@@ -48,6 +48,7 @@ def test_1bucket_6shards(self):
             filetype_suffix='tfrecord',
         ),
         shard_config=shard_utils.ShardConfig(num_shards=6),
+        max_size=2,
     )
     self.assertEqual(
         specs,
@@ -134,6 +135,7 @@ def test_4buckets_2shards(self):
             filetype_suffix='tfrecord',
         ),
         shard_config=shard_utils.ShardConfig(num_shards=2),
+        max_size=2,
     )
     self.assertEqual(
         specs,
