Skip to content

Commit 46f88b3

Browse files
tomvdw
authored and
The TensorFlow Datasets Authors
committed
Use the max size of serialized examples to find a safe number of shards
If we know the max size of serialized examples, then we can account for the worst-case scenario where one shard would get only examples of the max size. This should prevent users from running into problems with shards that are too large.

PiperOrigin-RevId: 726377778
1 parent 281ce2d commit 46f88b3

File tree

4 files changed

+136
-21
lines changed

4 files changed

+136
-21
lines changed

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def _compute_shard_specs(
406406
# HF split size is good enough for estimating the number of shards.
407407
num_shards = shard_utils.ShardConfig.calculate_number_shards(
408408
total_size=hf_split_info.num_bytes,
409+
max_size=None,
409410
num_examples=hf_split_info.num_examples,
410411
uses_precise_sharding=False,
411412
)

tensorflow_datasets/core/utils/shard_utils.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,39 @@ class ShardConfig:
5757
def calculate_number_shards(
5858
cls,
5959
total_size: int,
60+
max_size: int | None,
6061
num_examples: int,
6162
uses_precise_sharding: bool = True,
6263
) -> int:
6364
"""Returns number of shards for num_examples of total_size in bytes.
6465
6566
Args:
66-
total_size: the size of the data (serialized, not couting any overhead).
67+
total_size: the size of the data (serialized, not counting any overhead).
68+
max_size: the maximum size of a single example (serialized, not counting
69+
any overhead).
6770
num_examples: the number of records in the data.
6871
uses_precise_sharding: whether a mechanism is used to exactly control how
6972
many examples go in each shard.
7073
"""
71-
total_size += num_examples * cls.overhead
72-
max_shards_number = total_size // cls.min_shard_size
74+
total_overhead = num_examples * cls.overhead
75+
total_size_with_overhead = total_size + total_overhead
7376
if uses_precise_sharding:
7477
max_shard_size = cls.max_shard_size
7578
else:
7679
# When the pipeline does not control exactly how many rows go into each
7780
# shard (called 'precise sharding' here), we use a smaller max shard size
7881
# so that the pipeline doesn't fail if a shard gets some more examples.
7982
max_shard_size = 0.9 * cls.max_shard_size
80-
min_shards_number = total_size // max_shard_size
83+
max_shard_size = max(1, max_shard_size)
84+
85+
if max_size is None:
86+
min_shards_number = max(1, total_size_with_overhead // max_shard_size)
87+
max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size)
88+
else:
89+
pessimistic_total_size = num_examples * (max_size + cls.overhead)
90+
min_shards_number = max(1, pessimistic_total_size // max_shard_size)
91+
max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size)
92+
8193
if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
8294
return 1024
8395
elif min_shards_number > 1024:
@@ -96,15 +108,22 @@ def calculate_number_shards(
96108
def get_number_shards(
97109
self,
98110
total_size: int,
111+
max_size: int | None,
99112
num_examples: int,
100113
uses_precise_sharding: bool = True,
101114
) -> int:
102115
if self.num_shards:
103116
return self.num_shards
104117
return self.calculate_number_shards(
105-
total_size, num_examples, uses_precise_sharding
118+
total_size=total_size,
119+
max_size=max_size,
120+
num_examples=num_examples,
121+
uses_precise_sharding=uses_precise_sharding,
106122
)
107123

124+
def replace(self, **kwargs: Any) -> ShardConfig:
125+
return dataclasses.replace(self, **kwargs)
126+
108127

109128
def get_shard_boundaries(
110129
num_examples: int,

tensorflow_datasets/core/utils/shard_utils_test.py

+91-9
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,102 @@
2222
class ShardConfigTest(parameterized.TestCase):
2323

2424
@parameterized.named_parameters(
25-
('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024),
26-
('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64),
27-
('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512),
28-
('xxl, 10 TiB', 10 << 40, 10**9, True, 11264),
29-
('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808),
30-
('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1),
31-
('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4),
25+
dict(
26+
testcase_name='imagenet train, 137 GiB',
27+
total_size=137 << 30,
28+
num_examples=1281167,
29+
uses_precise_sharding=True,
30+
max_size=None,
31+
expected_num_shards=1024,
32+
),
33+
dict(
34+
testcase_name='imagenet evaluation, 6.3 GiB',
35+
total_size=6300 * (1 << 20),
36+
num_examples=50000,
37+
uses_precise_sharding=True,
38+
max_size=None,
39+
expected_num_shards=64,
40+
),
41+
dict(
42+
testcase_name='very large, but few examples, 52 GiB',
43+
total_size=52 << 30,
44+
num_examples=512,
45+
uses_precise_sharding=True,
46+
max_size=None,
47+
expected_num_shards=512,
48+
),
49+
dict(
50+
testcase_name='xxl, 10 TiB',
51+
total_size=10 << 40,
52+
num_examples=10**9,
53+
uses_precise_sharding=True,
54+
max_size=None,
55+
expected_num_shards=11264,
56+
),
57+
dict(
58+
testcase_name='xxl, 10 PiB, 100B examples',
59+
total_size=10 << 50,
60+
num_examples=10**11,
61+
uses_precise_sharding=True,
62+
max_size=None,
63+
expected_num_shards=10487808,
64+
),
65+
dict(
66+
testcase_name='xs, 100 MiB, 100K records',
67+
total_size=10 << 20,
68+
num_examples=100 * 10**3,
69+
uses_precise_sharding=True,
70+
max_size=None,
71+
expected_num_shards=1,
72+
),
73+
dict(
74+
testcase_name='m, 499 MiB, 200K examples',
75+
total_size=400 << 20,
76+
num_examples=200 * 10**3,
77+
uses_precise_sharding=True,
78+
max_size=None,
79+
expected_num_shards=4,
80+
),
81+
dict(
82+
testcase_name='100GiB, even example sizes',
83+
num_examples=1e9, # 1B examples
84+
total_size=1e9 * 1000, # On average 1000 bytes per example
85+
max_size=1000,  # Max example size is 1000 bytes (same as the average)
86+
uses_precise_sharding=True,
87+
expected_num_shards=1024,
88+
),
89+
dict(
90+
testcase_name='100GiB, uneven example sizes',
91+
num_examples=1e9, # 1B examples
92+
total_size=1e9 * 1000, # On average 1000 bytes per example
93+
max_size=4 * 1000, # Max example size is 4000 bytes
94+
uses_precise_sharding=True,
95+
expected_num_shards=4096,
96+
),
97+
dict(
98+
testcase_name='100GiB, very uneven example sizes',
99+
num_examples=1e9, # 1B examples
100+
total_size=1e9 * 1000, # On average 1000 bytes per example
101+
max_size=16 * 1000, # Max example size is 16x the average bytes
102+
uses_precise_sharding=True,
103+
expected_num_shards=15360,
104+
),
32105
)
33106
def test_get_number_shards_default_config(
34-
self, total_size, num_examples, uses_precise_sharding, expected_num_shards
107+
self,
108+
total_size: int,
109+
num_examples: int,
110+
uses_precise_sharding: bool,
111+
max_size: int,
112+
expected_num_shards: int,
35113
):
36114
shard_config = shard_utils.ShardConfig()
37115
self.assertEqual(
38116
expected_num_shards,
39117
shard_config.get_number_shards(
40118
total_size=total_size,
41119
num_examples=num_examples,
120+
max_size=max_size,
42121
uses_precise_sharding=uses_precise_sharding,
43122
),
44123
)
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
48127
self.assertEqual(
49128
42,
50129
shard_config.get_number_shards(
51-
total_size=100, num_examples=1, uses_precise_sharding=True
130+
total_size=100,
131+
max_size=100,
132+
num_examples=1,
133+
uses_precise_sharding=True,
52134
),
53135
)
54136

tensorflow_datasets/core/writer.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike:
116116
def _get_shard_specs(
117117
num_examples: int,
118118
total_size: int,
119+
max_size: int | None,
119120
bucket_lengths: Sequence[int],
120121
filename_template: naming.ShardedFileTemplate,
121122
shard_config: shard_utils.ShardConfig,
@@ -125,11 +126,14 @@ def _get_shard_specs(
125126
Args:
126127
num_examples: int, number of examples in split.
127128
total_size: int (bytes), sum of example sizes.
129+
max_size: int (bytes), maximum size of a single example.
128130
bucket_lengths: list of ints, number of examples in each bucket.
129131
filename_template: template to format sharded filenames.
130132
shard_config: the configuration for creating shards.
131133
"""
132-
num_shards = shard_config.get_number_shards(total_size, num_examples)
134+
num_shards = shard_config.get_number_shards(
135+
total_size=total_size, max_size=max_size, num_examples=num_examples
136+
)
133137
shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards)
134138
shard_specs = []
135139
bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
@@ -372,6 +376,7 @@ def finalize(self) -> tuple[list[int], int]:
372376
shard_specs = _get_shard_specs(
373377
num_examples=self._shuffler.num_examples,
374378
total_size=self._shuffler.size,
379+
max_size=None,
375380
bucket_lengths=self._shuffler.bucket_lengths,
376381
filename_template=self._filename_template,
377382
shard_config=self._shard_config,
@@ -589,10 +594,13 @@ def _write_final_shard(
589594
id=shard_id, num_examples=len(example_by_key), size=shard_size
590595
)
591596

592-
def _number_of_shards(self, num_examples: int, total_size: int) -> int:
597+
def _number_of_shards(
598+
self, num_examples: int, total_size: int, max_size: int
599+
) -> int:
593600
"""Returns the number of shards."""
594601
num_shards = self._shard_config.get_number_shards(
595602
total_size=total_size,
603+
max_size=max_size,
596604
num_examples=num_examples,
597605
uses_precise_sharding=False,
598606
)
@@ -658,16 +666,21 @@ def write_from_pcollection(self, examples_pcollection):
658666
| "CountExamples" >> beam.combiners.Count.Globally()
659667
| "CheckValidNumExamples" >> beam.Map(self._check_num_examples)
660668
)
669+
serialized_example_sizes = (
670+
serialized_examples | beam.Values() | beam.Map(len)
671+
)
661672
total_size = beam.pvalue.AsSingleton(
662-
serialized_examples
663-
| beam.Values()
664-
| beam.Map(len)
665-
| "TotalSize" >> beam.CombineGlobally(sum)
673+
serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum)
674+
)
675+
max_size = beam.pvalue.AsSingleton(
676+
serialized_example_sizes | "MaxSize" >> beam.CombineGlobally(max)
666677
)
667678
ideal_num_shards = beam.pvalue.AsSingleton(
668679
num_examples
669680
| "NumberOfShards"
670-
>> beam.Map(self._number_of_shards, total_size=total_size)
681+
>> beam.Map(
682+
self._number_of_shards, total_size=total_size, max_size=max_size
683+
)
671684
)
672685

673686
examples_per_shard = (

0 commit comments

Comments
 (0)