Skip to content

Commit 21c0be4

Browse files
tomvdw authored and
The TensorFlow Datasets Authors
committed
Use the max size of serialized examples to find a safe number of shards
If we know the max size of serialized examples, then we can account for the worst case scenario where one shard would get only examples of the max size. This hopefully should prevent users running into problems with having too big shards. PiperOrigin-RevId: 726377778
1 parent d0c47fc commit 21c0be4

File tree

8 files changed

+171
-26
lines changed

8 files changed

+171
-26
lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def _compute_shard_specs(
406406
# HF split size is good enough for estimating the number of shards.
407407
num_shards = shard_utils.ShardConfig.calculate_number_shards(
408408
total_size=hf_split_info.num_bytes,
409+
max_example_size=None,
409410
num_examples=hf_split_info.num_examples,
410411
uses_precise_sharding=False,
411412
)

tensorflow_datasets/core/reader_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def _write_tfrecord(self, split_name, shards_number, records):
9797
shard_specs = writer_lib._get_shard_specs(
9898
num_examples=num_examples,
9999
total_size=0,
100+
max_example_size=None,
100101
bucket_lengths=[num_examples],
101102
filename_template=filename_template,
102103
shard_config=shard_utils.ShardConfig(num_shards=shards_number),

tensorflow_datasets/core/shuffle.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__(
244244
self._total_bytes = 0
245245
# To keep data in memory until enough data has been gathered.
246246
self._in_memory = True
247-
self._mem_buffer = []
247+
self._mem_buffer: list[type_utils.KeySerializedExample] = []
248248
self._seen_keys: set[int] = set()
249249
self._num_examples = 0
250250

@@ -272,10 +272,10 @@ def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
272272
if self._total_bytes > MAX_MEM_BUFFER_SIZE:
273273
for hkey, data in self._mem_buffer:
274274
self._add_to_bucket(hkey, data)
275-
self._mem_buffer = None
275+
self._mem_buffer = []
276276
self._in_memory = False
277277

278-
def add(self, key: type_utils.Key, data: bytes) -> bool:
278+
def add(self, key: type_utils.Key, data: bytes) -> None:
279279
"""Add (key, data) to shuffler."""
280280
if self._read_only:
281281
raise AssertionError('add() cannot be called after __iter__.')

tensorflow_datasets/core/shuffle_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def _test_items(self, salt, items, expected_order, disable_shuffling=False):
199199
for key, item in items:
200200
shuffler.add(key, item)
201201
self.assertEqual(shuffler.size, _TOTAL_SIZE)
202+
self.assertGreater(shuffler.max_size, 0)
202203
if not shuffler._in_memory: # Check size of temporary bucket files
203204
expected_size = (16 + 8) * len(items) + sum(len(t[1]) for t in items)
204205
size = 0

tensorflow_datasets/core/utils/shard_utils.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,47 @@ class ShardConfig:
5757
def calculate_number_shards(
5858
cls,
5959
total_size: int,
60+
max_example_size: int | Sequence[int] | None,
6061
num_examples: int,
6162
uses_precise_sharding: bool = True,
6263
) -> int:
6364
"""Returns number of shards for num_examples of total_size in bytes.
6465
6566
Args:
66-
total_size: the size of the data (serialized, not couting any overhead).
67+
total_size: the size of the data (serialized, not counting any overhead).
68+
max_example_size: the maximum size of a single example (serialized, not
69+
counting any overhead).
6770
num_examples: the number of records in the data.
6871
uses_precise_sharding: whether a mechanism is used to exactly control how
6972
many examples go in each shard.
7073
"""
71-
total_size += num_examples * cls.overhead
72-
max_shards_number = total_size // cls.min_shard_size
74+
total_overhead = num_examples * cls.overhead
75+
total_size_with_overhead = total_size + total_overhead
7376
if uses_precise_sharding:
7477
max_shard_size = cls.max_shard_size
7578
else:
7679
# When the pipeline does not control exactly how many rows go into each
7780
# shard (called 'precise sharding' here), we use a smaller max shard size
7881
# so that the pipeline doesn't fail if a shard gets some more examples.
7982
max_shard_size = 0.9 * cls.max_shard_size
80-
min_shards_number = total_size // max_shard_size
83+
max_shard_size = max(1, max_shard_size)
84+
85+
if max_example_size is None:
86+
min_shards_number = max(1, total_size_with_overhead // max_shard_size)
87+
max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size)
88+
else:
89+
if isinstance(max_example_size, Sequence):
90+
if len(max_example_size) == 1:
91+
max_example_size = max_example_size[0]
92+
else:
93+
raise ValueError(
94+
'max_example_size must be a single value or None, got'
95+
f' {max_example_size}'
96+
)
97+
pessimistic_total_size = num_examples * (max_example_size + cls.overhead)
98+
min_shards_number = max(1, pessimistic_total_size // max_shard_size)
99+
max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size)
100+
81101
if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
82102
return 1024
83103
elif min_shards_number > 1024:
@@ -96,15 +116,22 @@ def calculate_number_shards(
96116
def get_number_shards(
97117
self,
98118
total_size: int,
119+
max_example_size: int | None,
99120
num_examples: int,
100121
uses_precise_sharding: bool = True,
101122
) -> int:
102123
if self.num_shards:
103124
return self.num_shards
104125
return self.calculate_number_shards(
105-
total_size, num_examples, uses_precise_sharding
126+
total_size=total_size,
127+
max_example_size=max_example_size,
128+
num_examples=num_examples,
129+
uses_precise_sharding=uses_precise_sharding,
106130
)
107131

132+
def replace(self, **kwargs: Any) -> ShardConfig:
133+
return dataclasses.replace(self, **kwargs)
134+
108135

109136
def get_shard_boundaries(
110137
num_examples: int,

tensorflow_datasets/core/utils/shard_utils_test.py

+91-9
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,102 @@
2222
class ShardConfigTest(parameterized.TestCase):
2323

2424
@parameterized.named_parameters(
25-
('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024),
26-
('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64),
27-
('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512),
28-
('xxl, 10 TiB', 10 << 40, 10**9, True, 11264),
29-
('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808),
30-
('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1),
31-
('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4),
25+
dict(
26+
testcase_name='imagenet train, 137 GiB',
27+
total_size=137 << 30,
28+
num_examples=1281167,
29+
uses_precise_sharding=True,
30+
max_size=None,
31+
expected_num_shards=1024,
32+
),
33+
dict(
34+
testcase_name='imagenet evaluation, 6.3 GiB',
35+
total_size=6300 * (1 << 20),
36+
num_examples=50000,
37+
uses_precise_sharding=True,
38+
max_size=None,
39+
expected_num_shards=64,
40+
),
41+
dict(
42+
testcase_name='very large, but few examples, 52 GiB',
43+
total_size=52 << 30,
44+
num_examples=512,
45+
uses_precise_sharding=True,
46+
max_size=None,
47+
expected_num_shards=512,
48+
),
49+
dict(
50+
testcase_name='xxl, 10 TiB',
51+
total_size=10 << 40,
52+
num_examples=10**9,
53+
uses_precise_sharding=True,
54+
max_size=None,
55+
expected_num_shards=11264,
56+
),
57+
dict(
58+
testcase_name='xxl, 10 PiB, 100B examples',
59+
total_size=10 << 50,
60+
num_examples=10**11,
61+
uses_precise_sharding=True,
62+
max_size=None,
63+
expected_num_shards=10487808,
64+
),
65+
dict(
66+
testcase_name='xs, 100 MiB, 100K records',
67+
total_size=10 << 20,
68+
num_examples=100 * 10**3,
69+
uses_precise_sharding=True,
70+
max_size=None,
71+
expected_num_shards=1,
72+
),
73+
dict(
74+
testcase_name='m, 499 MiB, 200K examples',
75+
total_size=400 << 20,
76+
num_examples=200 * 10**3,
77+
uses_precise_sharding=True,
78+
max_size=None,
79+
expected_num_shards=4,
80+
),
81+
dict(
82+
testcase_name='100GiB, even example sizes',
83+
num_examples=1e9, # 1B examples
84+
total_size=1e9 * 1000, # On average 1000 bytes per example
85+
max_size=1000, # Max example size is 4000 bytes
86+
uses_precise_sharding=True,
87+
expected_num_shards=1024,
88+
),
89+
dict(
90+
testcase_name='100GiB, uneven example sizes',
91+
num_examples=1e9, # 1B examples
92+
total_size=1e9 * 1000, # On average 1000 bytes per example
93+
max_size=4 * 1000, # Max example size is 4000 bytes
94+
uses_precise_sharding=True,
95+
expected_num_shards=4096,
96+
),
97+
dict(
98+
testcase_name='100GiB, very uneven example sizes',
99+
num_examples=1e9, # 1B examples
100+
total_size=1e9 * 1000, # On average 1000 bytes per example
101+
max_size=16 * 1000, # Max example size is 16x the average bytes
102+
uses_precise_sharding=True,
103+
expected_num_shards=15360,
104+
),
32105
)
33106
def test_get_number_shards_default_config(
34-
self, total_size, num_examples, uses_precise_sharding, expected_num_shards
107+
self,
108+
total_size: int,
109+
num_examples: int,
110+
uses_precise_sharding: bool,
111+
max_size: int,
112+
expected_num_shards: int,
35113
):
36114
shard_config = shard_utils.ShardConfig()
37115
self.assertEqual(
38116
expected_num_shards,
39117
shard_config.get_number_shards(
40118
total_size=total_size,
41119
num_examples=num_examples,
120+
max_example_size=max_size, # max(1, total_size // num_examples),
42121
uses_precise_sharding=uses_precise_sharding,
43122
),
44123
)
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
48127
self.assertEqual(
49128
42,
50129
shard_config.get_number_shards(
51-
total_size=100, num_examples=1, uses_precise_sharding=True
130+
total_size=100,
131+
max_example_size=100,
132+
num_examples=1,
133+
uses_precise_sharding=True,
52134
),
53135
)
54136

tensorflow_datasets/core/writer.py

+40-9
Original file line numberDiff line numberDiff line change
@@ -116,20 +116,26 @@ def _get_index_path(path: str) -> epath.PathLike:
116116
def _get_shard_specs(
117117
num_examples: int,
118118
total_size: int,
119+
max_example_size: int | None,
119120
bucket_lengths: Sequence[int],
120121
filename_template: naming.ShardedFileTemplate,
121122
shard_config: shard_utils.ShardConfig,
122123
) -> Sequence[_ShardSpec]:
123124
"""Returns list of _ShardSpec instances, corresponding to shards to write.
124125
125126
Args:
126-
num_examples: int, number of examples in split.
127-
total_size: int (bytes), sum of example sizes.
127+
num_examples: number of examples in split.
128+
total_size: total size in bytes, i.e., the sum of example sizes.
129+
max_example_size: maximum size in bytes of a single example.
128130
bucket_lengths: list of ints, number of examples in each bucket.
129131
filename_template: template to format sharded filenames.
130132
shard_config: the configuration for creating shards.
131133
"""
132-
num_shards = shard_config.get_number_shards(total_size, num_examples)
134+
num_shards = shard_config.get_number_shards(
135+
total_size=total_size,
136+
max_example_size=max_example_size,
137+
num_examples=num_examples,
138+
)
133139
shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards)
134140
shard_specs = []
135141
bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
@@ -350,6 +356,7 @@ def __init__(
350356
self._filename_template = filename_template
351357
self._shard_config = shard_config or shard_utils.ShardConfig()
352358
self._example_writer = example_writer
359+
self._max_example_size = 0
353360

354361
def write(self, key: int | bytes, example: Example):
355362
"""Writes given example.
@@ -363,6 +370,9 @@ def write(self, key: int | bytes, example: Example):
363370
"""
364371
serialized_example = self._serializer.serialize_example(example=example)
365372
self._shuffler.add(key, serialized_example)
373+
self._max_example_size = max(
374+
self._max_example_size, len(serialized_example)
375+
)
366376

367377
def finalize(self) -> tuple[list[int], int]:
368378
"""Effectively writes examples to the shards."""
@@ -372,6 +382,7 @@ def finalize(self) -> tuple[list[int], int]:
372382
shard_specs = _get_shard_specs(
373383
num_examples=self._shuffler.num_examples,
374384
total_size=self._shuffler.size,
385+
max_example_size=self._max_example_size,
375386
bucket_lengths=self._shuffler.bucket_lengths,
376387
filename_template=self._filename_template,
377388
shard_config=self._shard_config,
@@ -589,10 +600,13 @@ def _write_final_shard(
589600
id=shard_id, num_examples=len(example_by_key), size=shard_size
590601
)
591602

592-
def _number_of_shards(self, num_examples: int, total_size: int) -> int:
603+
def _number_of_shards(
604+
self, num_examples: int, total_size: int, max_example_size: int
605+
) -> int:
593606
"""Returns the number of shards."""
594607
num_shards = self._shard_config.get_number_shards(
595608
total_size=total_size,
609+
max_example_size=max_example_size,
596610
num_examples=num_examples,
597611
uses_precise_sharding=False,
598612
)
@@ -658,16 +672,26 @@ def write_from_pcollection(self, examples_pcollection):
658672
| "CountExamples" >> beam.combiners.Count.Globally()
659673
| "CheckValidNumExamples" >> beam.Map(self._check_num_examples)
660674
)
675+
serialized_example_sizes = (
676+
serialized_examples | beam.Values() | beam.Map(len)
677+
)
661678
total_size = beam.pvalue.AsSingleton(
662-
serialized_examples
663-
| beam.Values()
664-
| beam.Map(len)
665-
| "TotalSize" >> beam.CombineGlobally(sum)
679+
serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum)
680+
)
681+
682+
max_example_size = beam.pvalue.AsSingleton(
683+
serialized_example_sizes
684+
| "TopExampleSize" >> beam.combiners.Top.Largest(1)
685+
| "MaxExampleSize" >> beam.CombineGlobally(_get_max_size)
666686
)
667687
ideal_num_shards = beam.pvalue.AsSingleton(
668688
num_examples
669689
| "NumberOfShards"
670-
>> beam.Map(self._number_of_shards, total_size=total_size)
690+
>> beam.Map(
691+
self._number_of_shards,
692+
total_size=total_size,
693+
max_example_size=max_example_size,
694+
)
671695
)
672696

673697
examples_per_shard = (
@@ -826,3 +850,10 @@ def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
826850
)
827851

828852
return shard_lengths, total_size_bytes
853+
854+
855+
def _get_max_size(sizes: Iterable[int]) -> int | None:
856+
sizes = list(sizes)
857+
if not sizes:
858+
return None
859+
return max(sizes)

tensorflow_datasets/core/writer_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_1bucket_6shards(self):
4848
filetype_suffix='tfrecord',
4949
),
5050
shard_config=shard_utils.ShardConfig(num_shards=6),
51+
max_example_size=2,
5152
)
5253
self.assertEqual(
5354
specs,
@@ -134,6 +135,7 @@ def test_4buckets_2shards(self):
134135
filetype_suffix='tfrecord',
135136
),
136137
shard_config=shard_utils.ShardConfig(num_shards=2),
138+
max_example_size=2,
137139
)
138140
self.assertEqual(
139141
specs,

0 commit comments

Comments (0)