6
6
import pytest
7
7
8
8
from tiledb .ml .readers ._tensor_schema .ranges import (
9
- InclusiveRange ,
9
+ ConstrainedPartitionsIntRange ,
10
10
IntRange ,
11
11
WeightedRange ,
12
12
)
13
13
14
14
15
- @pytest .mark .parametrize ("values" , [None , 42 , 3.14 ])
16
- def test_inclusive_range_factory_type_error (values ):
17
- with pytest .raises (TypeError ) as excinfo :
18
- InclusiveRange .factory (values )
19
- assert "Cannot create inclusive range" in str (excinfo .value )
20
-
21
-
22
15
class TestIntRange :
23
- values = range (10 , 20 )
24
- r = InclusiveRange .factory (values )
16
+ r = IntRange (10 , 19 )
25
17
26
18
def test_basic (self ):
27
19
assert self .r .min == 10
28
20
assert self .r .max == 19
29
21
assert self .r .weight == 10
30
22
assert len (self .r ) == 10
31
23
32
- @pytest .mark .parametrize (
33
- "values" ,
34
- [
35
- values ,
36
- list (values ),
37
- set (values ),
38
- iter (values ),
39
- reversed (values ),
40
- Counter (values ),
41
- np .array (values ),
42
- range (19 , 9 , - 1 ),
43
- np .arange (19 , 9 , - 1 ),
44
- ],
45
- )
46
- def test_equal (self , values ):
47
- assert_equal_ranges (self .r , InclusiveRange .factory (values ), IntRange )
24
+ def test_equal (self ):
25
+ assert_equal_ranges (self .r , IntRange (10 , 19 ))
26
+ assert self .r != IntRange (0 , 9 )
27
+ assert self .r != IntRange (10 , 20 )
28
+ assert self .r != IntRange (11 , 19 )
29
+ assert self .r != WeightedRange .from_mapping (dict .fromkeys (self .r .values , 2 ))
48
30
49
- @pytest .mark .parametrize (
50
- "values" ,
51
- [
52
- np .array (values , dtype = object ),
53
- range (0 , 10 ),
54
- range (10 , 21 ),
55
- range (11 , 20 ),
56
- range (10 , 20 , 2 ),
57
- ],
58
- )
59
- def test_not_equal (self , values ):
60
- r = InclusiveRange .factory (values )
61
- assert self .r != r
62
- assert not self .r .equal_values (r )
31
+ def test_equal_values (self ):
32
+ assert self .r .equal_values (IntRange (10 , 19 ))
33
+ assert not self .r .equal_values (IntRange (10 , 20 ))
34
+ assert not self .r .equal_values (IntRange (11 , 19 ))
35
+ assert self .r .equal_values (
36
+ WeightedRange .from_mapping (dict .fromkeys (self .r .values , 2 ))
37
+ )
63
38
64
39
def test_indices (self ):
65
40
np .testing .assert_array_equal (
@@ -83,22 +58,22 @@ def test_indices_error(self, values):
83
58
@pytest .mark .parametrize (
84
59
"k,expected_bounds" ,
85
60
[
86
- (1 , [(10 , 20 )]),
87
- (2 , [(10 , 15 ), (15 , 20 )]),
88
- (3 , [(10 , 14 ), (14 , 17 ), (17 , 20 )]),
89
- (4 , [(10 , 13 ), (13 , 16 ), (16 , 18 ), (18 , 20 )]),
90
- (5 , [(10 , 12 ), (12 , 14 ), (14 , 16 ), (16 , 18 ), (18 , 20 )]),
91
- (6 , [(10 , 12 ), (12 , 14 ), (14 , 16 ), (16 , 18 ), (18 , 19 ), (19 , 20 )]),
92
- (7 , [(10 , 12 ), (12 , 14 ), (14 , 16 )] + [(i , i + 1 ) for i in range (16 , 20 )]),
93
- (8 , [(10 , 12 ), (12 , 14 )] + [(i , i + 1 ) for i in range (14 , 20 )]),
94
- (9 , [(10 , 12 )] + [(i , i + 1 ) for i in range (12 , 20 )]),
95
- (10 , [(i , i + 1 ) for i in range (10 , 20 )]),
61
+ (1 , [(10 , 19 )]),
62
+ (2 , [(10 , 14 ), (15 , 19 )]),
63
+ (3 , [(10 , 13 ), (14 , 16 ), (17 , 19 )]),
64
+ (4 , [(10 , 12 ), (13 , 15 ), (16 , 17 ), (18 , 19 )]),
65
+ (5 , [(10 , 11 ), (12 , 13 ), (14 , 15 ), (16 , 17 ), (18 , 19 )]),
66
+ (6 , [(10 , 11 ), (12 , 13 ), (14 , 15 ), (16 , 17 ), (18 , 18 ), (19 , 19 )]),
67
+ (7 , [(10 , 11 ), (12 , 13 ), (14 , 15 )] + [(i , i ) for i in range (16 , 20 )]),
68
+ (8 , [(10 , 11 ), (12 , 13 )] + [(i , i ) for i in range (14 , 20 )]),
69
+ (9 , [(10 , 11 )] + [(i , i ) for i in range (12 , 20 )]),
70
+ (10 , [(i , i ) for i in range (10 , 20 )]),
96
71
],
97
72
)
98
73
def test_partition_by_count (self , k , expected_bounds ):
99
74
ranges = list (self .r .partition_by_count (k ))
100
75
assert len (ranges ) == k
101
- expected_ranges = [InclusiveRange . factory ( range ( * bs )) for bs in expected_bounds ]
76
+ expected_ranges = [IntRange ( * bounds ) for bounds in expected_bounds ]
102
77
assert ranges == expected_ranges
103
78
104
79
def test_partition_by_count_error (self ):
@@ -110,33 +85,93 @@ def test_partition_by_count_error(self):
110
85
@pytest .mark .parametrize (
111
86
"max_weight,expected_bounds" ,
112
87
[
113
- (1 , [(i , i + 1 ) for i in range (10 , 20 )]),
114
- (2 , [(10 , 12 ), (12 , 14 ), (14 , 16 ), (16 , 18 ), (18 , 20 )]),
115
- (3 , [(10 , 13 ), (13 , 16 ), (16 , 19 ), (19 , 20 )]),
116
- (4 , [(10 , 14 ), (14 , 18 ), (18 , 20 )]),
117
- (5 , [(10 , 15 ), (15 , 20 )]),
118
- (6 , [(10 , 16 ), (16 , 20 )]),
119
- (7 , [(10 , 17 ), (17 , 20 )]),
120
- (8 , [(10 , 18 ), (18 , 20 )]),
121
- (9 , [(10 , 19 ), (19 , 20 )]),
122
- (10 , [(10 , 20 )]),
123
- (11 , [(10 , 20 )]),
88
+ (1 , [(i , i ) for i in range (10 , 20 )]),
89
+ (2 , [(10 , 11 ), (12 , 13 ), (14 , 15 ), (16 , 17 ), (18 , 19 )]),
90
+ (3 , [(10 , 12 ), (13 , 15 ), (16 , 18 ), (19 , 19 )]),
91
+ (4 , [(10 , 13 ), (14 , 17 ), (18 , 19 )]),
92
+ (5 , [(10 , 14 ), (15 , 19 )]),
93
+ (6 , [(10 , 15 ), (16 , 19 )]),
94
+ (7 , [(10 , 16 ), (17 , 19 )]),
95
+ (8 , [(10 , 17 ), (18 , 19 )]),
96
+ (9 , [(10 , 18 ), (19 , 19 )]),
97
+ (10 , [(10 , 19 )]),
98
+ (11 , [(10 , 19 )]),
124
99
],
125
100
)
126
101
def test_partition_by_weight (self , max_weight , expected_bounds ):
127
102
ranges = list (self .r .partition_by_weight (max_weight ))
128
103
assert max (r .weight for r in ranges ) <= max_weight
129
- expected_ranges = [InclusiveRange . factory ( range ( * bs )) for bs in expected_bounds ]
104
+ expected_ranges = [IntRange ( * bounds ) for bounds in expected_bounds ]
130
105
assert ranges == expected_ranges
131
106
132
107
def test_pickle (self ):
133
108
assert pickle .loads (pickle .dumps (self .r )) == self .r
134
109
135
110
111
class TestConstrainedPartitionsIntRange:
    """Partitioning tests for ``ConstrainedPartitionsIntRange``.

    The shared fixture is built with ``(10, 29, range(1, 101, 4))``. Besides
    comparing partition bounds against hard-coded expectations, the tests
    check every partition after the first one against the fixture's
    ``start_offsets`` attribute.
    """

    r = ConstrainedPartitionsIntRange(10, 29, range(1, 101, 4))

    def _bounds_checked_against_offsets(self, partitions):
        """Return ``(min, max)`` for each partition after the offset check.

        The first partition is exempt from the check; every later one must
        have its ``min`` contained in ``self.r.start_offsets``.
        """
        allowed_starts = self.r.start_offsets
        for part in partitions[1:]:
            assert part.min in allowed_starts
        return [(part.min, part.max) for part in partitions]

    @pytest.mark.parametrize(
        "k,expected_bounds",
        [
            (1, [(10, 29)]),
            (2, [(10, 20), (21, 29)]),
            (3, [(10, 16), (17, 24), (25, 29)]),
            (4, [(10, 16), (17, 20), (21, 24), (25, 29)]),
            (5, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
            (6, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 28), (29, 29)]),
        ],
    )
    def test_partition_by_count(self, k, expected_bounds):
        """Splitting into ``k`` parts yields exactly the expected bounds."""
        partitions = list(self.r.partition_by_count(k))
        assert len(partitions) == k
        actual_bounds = self._bounds_checked_against_offsets(partitions)
        assert actual_bounds == expected_bounds

    @pytest.mark.parametrize("k", [7, 8, 9, 10])
    def test_partition_by_count_error(self, k):
        """These partition counts are rejected with a ``ValueError``."""
        # The generator is drained so that a lazily raised error surfaces.
        with pytest.raises(ValueError, match="Cannot partition range"):
            list(self.r.partition_by_count(k))

    @pytest.mark.parametrize(
        "max_weight,expected_bounds",
        [
            (4, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 28), (29, 29)]),
            (5, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
            (6, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
            (7, [(10, 16), (17, 20), (21, 24), (25, 29)]),
            (8, [(10, 16), (17, 24), (25, 29)]),
            (9, [(10, 16), (17, 24), (25, 29)]),
            (10, [(10, 16), (17, 24), (25, 29)]),
            (11, [(10, 20), (21, 29)]),
        ],
    )
    def test_partition_by_weight(self, max_weight, expected_bounds):
        """No partition exceeds ``max_weight`` and the bounds match."""
        partitions = list(self.r.partition_by_weight(max_weight))
        heaviest = max(part.weight for part in partitions)
        assert heaviest <= max_weight
        actual_bounds = self._bounds_checked_against_offsets(partitions)
        assert actual_bounds == expected_bounds

    @pytest.mark.parametrize("max_weight", [1, 2, 3])
    def test_partition_by_weight_error(self, max_weight):
        """These weight limits are rejected with a ``ValueError``."""
        # The generator is drained so that a lazily raised error surfaces.
        with pytest.raises(ValueError, match="Cannot partition range"):
            list(self.r.partition_by_weight(max_weight))
167
+
168
+
136
169
class TestWeightedRange :
137
170
values = ("e" , "f" , "a" , "d" , "a" , "c" , "d" , "a" , "f" , "c" , "f" , "f" , "b" , "d" )
138
- r = InclusiveRange .factory (values )
139
- r2 = InclusiveRange .factory ({v : timedelta (c ) for v , c in Counter (values ).items ()})
171
+ r = WeightedRange .from_mapping (Counter (values ))
172
+ r2 = WeightedRange .from_mapping (
173
+ {v : timedelta (c ) for v , c in Counter (values ).items ()}
174
+ )
140
175
141
176
@pytest .mark .parametrize ("r" , [r , r2 ])
142
177
def test_basic (self , r ):
@@ -145,44 +180,19 @@ def test_basic(self, r):
145
180
assert len (r ) == 6
146
181
assert r .weight == 14 if r is self .r else timedelta (14 )
147
182
148
- @pytest .mark .parametrize (
149
- "values" ,
150
- [
151
- values ,
152
- list (values ),
153
- iter (values ),
154
- reversed (values ),
155
- Counter (values ),
156
- np .array (values ),
157
- np .array (values , dtype = object ),
158
- ],
159
- )
160
- def test_equal (self , values ):
161
- assert_equal_ranges (self .r , InclusiveRange .factory (values ), WeightedRange )
162
-
163
- def test_not_equal (self ):
164
- assert self .r != InclusiveRange .factory (set (self .values ))
165
- assert self .r != InclusiveRange .factory (range (len (set (self .values ))))
183
+ def test_equal (self ):
184
+ assert_equal_ranges (self .r , WeightedRange .from_mapping (Counter (self .values )))
185
+ assert self .r != WeightedRange .from_mapping (Counter (set (self .values )))
186
+ assert self .r != IntRange (0 , len (set (self .values )) - 1 )
166
187
167
188
def test_equal_values (self ):
168
- assert self .r .equal_values (InclusiveRange .factory (set (self .values )))
169
-
170
- r = InclusiveRange .factory ([1 , 2 , 3 , 3 , 4 , 5 ])
171
- assert r .equal_values (InclusiveRange .factory (range (1 , 6 )))
172
- assert not r .equal_values (InclusiveRange .factory (range (1 , 7 )))
173
- assert not r .equal_values (InclusiveRange .factory (range (2 , 7 )))
174
-
175
- def test_strided_range (self ):
176
- assert_equal_ranges (
177
- InclusiveRange .factory (range (10 , 20 , 3 )),
178
- InclusiveRange .factory ([10 , 13 , 16 , 19 ]),
179
- WeightedRange ,
180
- )
181
- assert_equal_ranges (
182
- InclusiveRange .factory (range (20 , 10 , - 3 )),
183
- InclusiveRange .factory ([11 , 14 , 17 , 20 ]),
184
- WeightedRange ,
189
+ assert self .r .equal_values (
190
+ WeightedRange .from_mapping (Counter (set (self .values )))
185
191
)
192
+ r = WeightedRange .from_mapping (Counter ([1 , 2 , 3 , 3 , 4 , 5 ]))
193
+ assert r .equal_values (IntRange (1 , 5 ))
194
+ assert not r .equal_values (IntRange (1 , 6 ))
195
+ assert not r .equal_values (IntRange (2 , 5 ))
186
196
187
197
def test_indices (self ):
188
198
np .testing .assert_array_equal (
@@ -219,15 +229,15 @@ def test_indices_error(self, values):
219
229
def test_partition_by_count (self , k , expected_mappings ):
220
230
ranges = list (self .r .partition_by_count (k ))
221
231
assert len (ranges ) == k
222
- expected_ranges = list (map (InclusiveRange . factory , expected_mappings ))
232
+ expected_ranges = list (map (WeightedRange . from_mapping , expected_mappings ))
223
233
assert ranges == expected_ranges
224
234
225
235
@parametrize_by_count
226
236
def test_partition_by_count2 (self , k , expected_mappings ):
227
237
ranges = list (self .r2 .partition_by_count (k ))
228
238
assert len (ranges ) == k
229
239
expected_ranges = [
230
- InclusiveRange . factory ({v : timedelta (w ) for v , w in mapping .items ()})
240
+ WeightedRange . from_mapping ({v : timedelta (w ) for v , w in mapping .items ()})
231
241
for mapping in expected_mappings
232
242
]
233
243
assert ranges == expected_ranges
@@ -261,7 +271,7 @@ def test_partition_by_count_error(self, r):
261
271
def test_partition_by_weight (self , max_weight , expected_mappings ):
262
272
ranges = list (self .r .partition_by_weight (max_weight ))
263
273
assert max (r .weight for r in ranges ) <= max_weight
264
- expected_ranges = list (map (InclusiveRange . factory , expected_mappings ))
274
+ expected_ranges = list (map (WeightedRange . from_mapping , expected_mappings ))
265
275
assert ranges == expected_ranges
266
276
267
277
@parametrize_by_max_weight
@@ -270,7 +280,7 @@ def test_partition_by_weight2(self, max_weight, expected_mappings):
270
280
ranges = list (self .r2 .partition_by_weight (max_weight ))
271
281
assert max (r .weight for r in ranges ) <= max_weight
272
282
expected_ranges = [
273
- InclusiveRange . factory ({v : timedelta (w ) for v , w in mapping .items ()})
283
+ WeightedRange . from_mapping ({v : timedelta (w ) for v , w in mapping .items ()})
274
284
for mapping in expected_mappings
275
285
]
276
286
assert ranges == expected_ranges
@@ -280,17 +290,16 @@ def test_partition_by_weight2(self, max_weight, expected_mappings):
280
290
)
281
291
def test_partition_by_weight_error (self , r , max_weights ):
282
292
for max_weight in max_weights :
283
- with pytest .raises (ValueError ):
293
+ with pytest .raises (ValueError ) as excinfo :
284
294
list (r .partition_by_weight (max_weight ))
295
+ assert "Cannot partition range" in str (excinfo .value )
285
296
286
297
def test_pickle (self ):
287
298
assert pickle .loads (pickle .dumps (self .r )) == self .r
288
299
assert pickle .loads (pickle .dumps (self .r2 )) == self .r2
289
300
290
301
291
- def assert_equal_ranges (r1 , r2 , cls ):
292
- assert isinstance (r1 , cls )
293
- assert isinstance (r2 , cls )
302
+ def assert_equal_ranges (r1 , r2 ):
294
303
assert r1 .min == r2 .min
295
304
assert r1 .max == r2 .max
296
305
assert r1 .weight == r2 .weight
0 commit comments