22
22
class ShardConfigTest (parameterized .TestCase ):
23
23
24
24
@parameterized.named_parameters(
    # Cases without a max example size (`max_size=None`).
    dict(
        testcase_name='imagenet train, 137 GiB',
        total_size=137 << 30,
        num_examples=1281167,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=1024,
    ),
    dict(
        testcase_name='imagenet evaluation, 6.3 GiB',
        total_size=6300 * (1 << 20),
        num_examples=50000,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=64,
    ),
    dict(
        testcase_name='very large, but few examples, 52 GiB',
        total_size=52 << 30,
        num_examples=512,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=512,
    ),
    dict(
        testcase_name='xxl, 10 TiB',
        total_size=10 << 40,
        num_examples=10**9,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=11264,
    ),
    dict(
        testcase_name='xxl, 10 PiB, 100B examples',
        total_size=10 << 50,
        num_examples=10**11,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=10487808,
    ),
    dict(
        # NOTE(review): the name says 100 MiB but `10 << 20` is 10 MiB;
        # confirm which one is intended before renaming (name is a test id).
        testcase_name='xs, 100 MiB, 100K records',
        total_size=10 << 20,
        num_examples=100 * 10**3,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=1,
    ),
    dict(
        # NOTE(review): the name says 499 MiB but `400 << 20` is 400 MiB;
        # confirm which one is intended before renaming (name is a test id).
        testcase_name='m, 499 MiB, 200K examples',
        total_size=400 << 20,
        num_examples=200 * 10**3,
        uses_precise_sharding=True,
        max_size=None,
        expected_num_shards=4,
    ),
    # Cases with an explicit max example size. All three share the same
    # total size and example count and only differ in `max_size`.
    # NOTE(review): the names say 100GiB, but 10**9 examples * 1000 bytes is
    # 10**12 bytes (~931 GiB); confirm the intended label.
    dict(
        testcase_name='100GiB, even example sizes',
        total_size=10**9 * 1000,  # On average 1000 bytes per example
        num_examples=10**9,  # 1B examples
        uses_precise_sharding=True,
        max_size=1000,  # Max example size equals the 1000-byte average
        expected_num_shards=1024,
    ),
    dict(
        testcase_name='100GiB, uneven example sizes',
        total_size=10**9 * 1000,  # On average 1000 bytes per example
        num_examples=10**9,  # 1B examples
        uses_precise_sharding=True,
        max_size=4 * 1000,  # Max example size is 4000 bytes (4x the average)
        expected_num_shards=4096,
    ),
    dict(
        testcase_name='100GiB, very uneven example sizes',
        total_size=10**9 * 1000,  # On average 1000 bytes per example
        num_examples=10**9,  # 1B examples
        uses_precise_sharding=True,
        max_size=16 * 1000,  # Max example size is 16x the average bytes
        expected_num_shards=15360,
    ),
)
def test_get_number_shards_default_config(
    self,
    total_size: int,
    num_examples: int,
    uses_precise_sharding: bool,
    max_size: int | None,
    expected_num_shards: int,
):
    """Checks the shard count returned by a default-constructed ShardConfig.

    Args:
        total_size: Value forwarded as `total_size` (the case names and
            comments describe it as a size in bytes).
        num_examples: Value forwarded as `num_examples`.
        uses_precise_sharding: Value forwarded as `uses_precise_sharding`.
        max_size: Value forwarded as `max_size`; `None` in the cases that do
            not supply a max example size.
        expected_num_shards: Shard count `get_number_shards` must return.
    """
    shard_config = shard_utils.ShardConfig()
    self.assertEqual(
        expected_num_shards,
        shard_config.get_number_shards(
            total_size=total_size,
            num_examples=num_examples,
            max_size=max_size,
            uses_precise_sharding=uses_precise_sharding,
        ),
    )
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
48
127
self .assertEqual (
49
128
42 ,
50
129
shard_config .get_number_shards (
51
- total_size = 100 , num_examples = 1 , uses_precise_sharding = True
130
+ total_size = 100 ,
131
+ max_size = 100 ,
132
+ num_examples = 1 ,
133
+ uses_precise_sharding = True ,
52
134
),
53
135
)
54
136
0 commit comments