Skip to content

Commit 1e7a755

Browse files
authored
Add ConstrainedPartitionsIntRange (#192)
* Fix InclusiveRange.equal_values(x, y) between IntRange and WeightedRange * Drop InclusiveRange.factory, add WeightedRange.from_mapping * Add ConstrainedPartitionsIntRange * Refactor ConstrainedPartitionsIntRange * Update InclusiveRange.__getstate__ to work with inherited __slots__
1 parent d2c139b commit 1e7a755

File tree

4 files changed

+221
-189
lines changed

4 files changed

+221
-189
lines changed

tests/readers/test_ranges.py

Lines changed: 117 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -6,60 +6,35 @@
66
import pytest
77

88
from tiledb.ml.readers._tensor_schema.ranges import (
9-
InclusiveRange,
9+
ConstrainedPartitionsIntRange,
1010
IntRange,
1111
WeightedRange,
1212
)
1313

1414

15-
@pytest.mark.parametrize("values", [None, 42, 3.14])
16-
def test_inclusive_range_factory_type_error(values):
17-
with pytest.raises(TypeError) as excinfo:
18-
InclusiveRange.factory(values)
19-
assert "Cannot create inclusive range" in str(excinfo.value)
20-
21-
2215
class TestIntRange:
23-
values = range(10, 20)
24-
r = InclusiveRange.factory(values)
16+
r = IntRange(10, 19)
2517

2618
def test_basic(self):
2719
assert self.r.min == 10
2820
assert self.r.max == 19
2921
assert self.r.weight == 10
3022
assert len(self.r) == 10
3123

32-
@pytest.mark.parametrize(
33-
"values",
34-
[
35-
values,
36-
list(values),
37-
set(values),
38-
iter(values),
39-
reversed(values),
40-
Counter(values),
41-
np.array(values),
42-
range(19, 9, -1),
43-
np.arange(19, 9, -1),
44-
],
45-
)
46-
def test_equal(self, values):
47-
assert_equal_ranges(self.r, InclusiveRange.factory(values), IntRange)
24+
def test_equal(self):
25+
assert_equal_ranges(self.r, IntRange(10, 19))
26+
assert self.r != IntRange(0, 9)
27+
assert self.r != IntRange(10, 20)
28+
assert self.r != IntRange(11, 19)
29+
assert self.r != WeightedRange.from_mapping(dict.fromkeys(self.r.values, 2))
4830

49-
@pytest.mark.parametrize(
50-
"values",
51-
[
52-
np.array(values, dtype=object),
53-
range(0, 10),
54-
range(10, 21),
55-
range(11, 20),
56-
range(10, 20, 2),
57-
],
58-
)
59-
def test_not_equal(self, values):
60-
r = InclusiveRange.factory(values)
61-
assert self.r != r
62-
assert not self.r.equal_values(r)
31+
def test_equal_values(self):
32+
assert self.r.equal_values(IntRange(10, 19))
33+
assert not self.r.equal_values(IntRange(10, 20))
34+
assert not self.r.equal_values(IntRange(11, 19))
35+
assert self.r.equal_values(
36+
WeightedRange.from_mapping(dict.fromkeys(self.r.values, 2))
37+
)
6338

6439
def test_indices(self):
6540
np.testing.assert_array_equal(
@@ -83,22 +58,22 @@ def test_indices_error(self, values):
8358
@pytest.mark.parametrize(
8459
"k,expected_bounds",
8560
[
86-
(1, [(10, 20)]),
87-
(2, [(10, 15), (15, 20)]),
88-
(3, [(10, 14), (14, 17), (17, 20)]),
89-
(4, [(10, 13), (13, 16), (16, 18), (18, 20)]),
90-
(5, [(10, 12), (12, 14), (14, 16), (16, 18), (18, 20)]),
91-
(6, [(10, 12), (12, 14), (14, 16), (16, 18), (18, 19), (19, 20)]),
92-
(7, [(10, 12), (12, 14), (14, 16)] + [(i, i + 1) for i in range(16, 20)]),
93-
(8, [(10, 12), (12, 14)] + [(i, i + 1) for i in range(14, 20)]),
94-
(9, [(10, 12)] + [(i, i + 1) for i in range(12, 20)]),
95-
(10, [(i, i + 1) for i in range(10, 20)]),
61+
(1, [(10, 19)]),
62+
(2, [(10, 14), (15, 19)]),
63+
(3, [(10, 13), (14, 16), (17, 19)]),
64+
(4, [(10, 12), (13, 15), (16, 17), (18, 19)]),
65+
(5, [(10, 11), (12, 13), (14, 15), (16, 17), (18, 19)]),
66+
(6, [(10, 11), (12, 13), (14, 15), (16, 17), (18, 18), (19, 19)]),
67+
(7, [(10, 11), (12, 13), (14, 15)] + [(i, i) for i in range(16, 20)]),
68+
(8, [(10, 11), (12, 13)] + [(i, i) for i in range(14, 20)]),
69+
(9, [(10, 11)] + [(i, i) for i in range(12, 20)]),
70+
(10, [(i, i) for i in range(10, 20)]),
9671
],
9772
)
9873
def test_partition_by_count(self, k, expected_bounds):
9974
ranges = list(self.r.partition_by_count(k))
10075
assert len(ranges) == k
101-
expected_ranges = [InclusiveRange.factory(range(*bs)) for bs in expected_bounds]
76+
expected_ranges = [IntRange(*bounds) for bounds in expected_bounds]
10277
assert ranges == expected_ranges
10378

10479
def test_partition_by_count_error(self):
@@ -110,33 +85,93 @@ def test_partition_by_count_error(self):
11085
@pytest.mark.parametrize(
11186
"max_weight,expected_bounds",
11287
[
113-
(1, [(i, i + 1) for i in range(10, 20)]),
114-
(2, [(10, 12), (12, 14), (14, 16), (16, 18), (18, 20)]),
115-
(3, [(10, 13), (13, 16), (16, 19), (19, 20)]),
116-
(4, [(10, 14), (14, 18), (18, 20)]),
117-
(5, [(10, 15), (15, 20)]),
118-
(6, [(10, 16), (16, 20)]),
119-
(7, [(10, 17), (17, 20)]),
120-
(8, [(10, 18), (18, 20)]),
121-
(9, [(10, 19), (19, 20)]),
122-
(10, [(10, 20)]),
123-
(11, [(10, 20)]),
88+
(1, [(i, i) for i in range(10, 20)]),
89+
(2, [(10, 11), (12, 13), (14, 15), (16, 17), (18, 19)]),
90+
(3, [(10, 12), (13, 15), (16, 18), (19, 19)]),
91+
(4, [(10, 13), (14, 17), (18, 19)]),
92+
(5, [(10, 14), (15, 19)]),
93+
(6, [(10, 15), (16, 19)]),
94+
(7, [(10, 16), (17, 19)]),
95+
(8, [(10, 17), (18, 19)]),
96+
(9, [(10, 18), (19, 19)]),
97+
(10, [(10, 19)]),
98+
(11, [(10, 19)]),
12499
],
125100
)
126101
def test_partition_by_weight(self, max_weight, expected_bounds):
127102
ranges = list(self.r.partition_by_weight(max_weight))
128103
assert max(r.weight for r in ranges) <= max_weight
129-
expected_ranges = [InclusiveRange.factory(range(*bs)) for bs in expected_bounds]
104+
expected_ranges = [IntRange(*bounds) for bounds in expected_bounds]
130105
assert ranges == expected_ranges
131106

132107
def test_pickle(self):
133108
assert pickle.loads(pickle.dumps(self.r)) == self.r
134109

135110

111+
class TestConstrainedPartitionsIntRange:
112+
r = ConstrainedPartitionsIntRange(10, 29, range(1, 101, 4))
113+
114+
@pytest.mark.parametrize(
115+
"k,expected_bounds",
116+
[
117+
(1, [(10, 29)]),
118+
(2, [(10, 20), (21, 29)]),
119+
(3, [(10, 16), (17, 24), (25, 29)]),
120+
(4, [(10, 16), (17, 20), (21, 24), (25, 29)]),
121+
(5, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
122+
(6, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 28), (29, 29)]),
123+
],
124+
)
125+
def test_partition_by_count(self, k, expected_bounds):
126+
ranges = list(self.r.partition_by_count(k))
127+
assert len(ranges) == k
128+
# all partitions after the first must start at a start_offset
129+
start_offsets = self.r.start_offsets
130+
assert all(r.min in start_offsets for r in ranges[1:])
131+
bounds = [(r.min, r.max) for r in ranges]
132+
assert bounds == expected_bounds
133+
134+
@pytest.mark.parametrize("k", [7, 8, 9, 10])
135+
def test_partition_by_count_error(self, k):
136+
with pytest.raises(ValueError) as excinfo:
137+
list(self.r.partition_by_count(k))
138+
assert "Cannot partition range" in str(excinfo.value)
139+
140+
@pytest.mark.parametrize(
141+
"max_weight,expected_bounds",
142+
[
143+
(4, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 28), (29, 29)]),
144+
(5, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
145+
(6, [(10, 12), (13, 16), (17, 20), (21, 24), (25, 29)]),
146+
(7, [(10, 16), (17, 20), (21, 24), (25, 29)]),
147+
(8, [(10, 16), (17, 24), (25, 29)]),
148+
(9, [(10, 16), (17, 24), (25, 29)]),
149+
(10, [(10, 16), (17, 24), (25, 29)]),
150+
(11, [(10, 20), (21, 29)]),
151+
],
152+
)
153+
def test_partition_by_weight(self, max_weight, expected_bounds):
154+
ranges = list(self.r.partition_by_weight(max_weight))
155+
assert max(r.weight for r in ranges) <= max_weight
156+
# all partitions after the first must start at a start_offset
157+
start_offsets = self.r.start_offsets
158+
assert all(r.min in start_offsets for r in ranges[1:])
159+
bounds = [(r.min, r.max) for r in ranges]
160+
assert bounds == expected_bounds
161+
162+
@pytest.mark.parametrize("max_weight", [1, 2, 3])
163+
def test_partition_by_weight_error(self, max_weight):
164+
with pytest.raises(ValueError) as excinfo:
165+
list(self.r.partition_by_weight(max_weight))
166+
assert "Cannot partition range" in str(excinfo.value)
167+
168+
136169
class TestWeightedRange:
137170
values = ("e", "f", "a", "d", "a", "c", "d", "a", "f", "c", "f", "f", "b", "d")
138-
r = InclusiveRange.factory(values)
139-
r2 = InclusiveRange.factory({v: timedelta(c) for v, c in Counter(values).items()})
171+
r = WeightedRange.from_mapping(Counter(values))
172+
r2 = WeightedRange.from_mapping(
173+
{v: timedelta(c) for v, c in Counter(values).items()}
174+
)
140175

141176
@pytest.mark.parametrize("r", [r, r2])
142177
def test_basic(self, r):
@@ -145,44 +180,19 @@ def test_basic(self, r):
145180
assert len(r) == 6
146181
assert r.weight == 14 if r is self.r else timedelta(14)
147182

148-
@pytest.mark.parametrize(
149-
"values",
150-
[
151-
values,
152-
list(values),
153-
iter(values),
154-
reversed(values),
155-
Counter(values),
156-
np.array(values),
157-
np.array(values, dtype=object),
158-
],
159-
)
160-
def test_equal(self, values):
161-
assert_equal_ranges(self.r, InclusiveRange.factory(values), WeightedRange)
162-
163-
def test_not_equal(self):
164-
assert self.r != InclusiveRange.factory(set(self.values))
165-
assert self.r != InclusiveRange.factory(range(len(set(self.values))))
183+
def test_equal(self):
184+
assert_equal_ranges(self.r, WeightedRange.from_mapping(Counter(self.values)))
185+
assert self.r != WeightedRange.from_mapping(Counter(set(self.values)))
186+
assert self.r != IntRange(0, len(set(self.values)) - 1)
166187

167188
def test_equal_values(self):
168-
assert self.r.equal_values(InclusiveRange.factory(set(self.values)))
169-
170-
r = InclusiveRange.factory([1, 2, 3, 3, 4, 5])
171-
assert r.equal_values(InclusiveRange.factory(range(1, 6)))
172-
assert not r.equal_values(InclusiveRange.factory(range(1, 7)))
173-
assert not r.equal_values(InclusiveRange.factory(range(2, 7)))
174-
175-
def test_strided_range(self):
176-
assert_equal_ranges(
177-
InclusiveRange.factory(range(10, 20, 3)),
178-
InclusiveRange.factory([10, 13, 16, 19]),
179-
WeightedRange,
180-
)
181-
assert_equal_ranges(
182-
InclusiveRange.factory(range(20, 10, -3)),
183-
InclusiveRange.factory([11, 14, 17, 20]),
184-
WeightedRange,
189+
assert self.r.equal_values(
190+
WeightedRange.from_mapping(Counter(set(self.values)))
185191
)
192+
r = WeightedRange.from_mapping(Counter([1, 2, 3, 3, 4, 5]))
193+
assert r.equal_values(IntRange(1, 5))
194+
assert not r.equal_values(IntRange(1, 6))
195+
assert not r.equal_values(IntRange(2, 5))
186196

187197
def test_indices(self):
188198
np.testing.assert_array_equal(
@@ -219,15 +229,15 @@ def test_indices_error(self, values):
219229
def test_partition_by_count(self, k, expected_mappings):
220230
ranges = list(self.r.partition_by_count(k))
221231
assert len(ranges) == k
222-
expected_ranges = list(map(InclusiveRange.factory, expected_mappings))
232+
expected_ranges = list(map(WeightedRange.from_mapping, expected_mappings))
223233
assert ranges == expected_ranges
224234

225235
@parametrize_by_count
226236
def test_partition_by_count2(self, k, expected_mappings):
227237
ranges = list(self.r2.partition_by_count(k))
228238
assert len(ranges) == k
229239
expected_ranges = [
230-
InclusiveRange.factory({v: timedelta(w) for v, w in mapping.items()})
240+
WeightedRange.from_mapping({v: timedelta(w) for v, w in mapping.items()})
231241
for mapping in expected_mappings
232242
]
233243
assert ranges == expected_ranges
@@ -261,7 +271,7 @@ def test_partition_by_count_error(self, r):
261271
def test_partition_by_weight(self, max_weight, expected_mappings):
262272
ranges = list(self.r.partition_by_weight(max_weight))
263273
assert max(r.weight for r in ranges) <= max_weight
264-
expected_ranges = list(map(InclusiveRange.factory, expected_mappings))
274+
expected_ranges = list(map(WeightedRange.from_mapping, expected_mappings))
265275
assert ranges == expected_ranges
266276

267277
@parametrize_by_max_weight
@@ -270,7 +280,7 @@ def test_partition_by_weight2(self, max_weight, expected_mappings):
270280
ranges = list(self.r2.partition_by_weight(max_weight))
271281
assert max(r.weight for r in ranges) <= max_weight
272282
expected_ranges = [
273-
InclusiveRange.factory({v: timedelta(w) for v, w in mapping.items()})
283+
WeightedRange.from_mapping({v: timedelta(w) for v, w in mapping.items()})
274284
for mapping in expected_mappings
275285
]
276286
assert ranges == expected_ranges
@@ -280,17 +290,16 @@ def test_partition_by_weight2(self, max_weight, expected_mappings):
280290
)
281291
def test_partition_by_weight_error(self, r, max_weights):
282292
for max_weight in max_weights:
283-
with pytest.raises(ValueError):
293+
with pytest.raises(ValueError) as excinfo:
284294
list(r.partition_by_weight(max_weight))
295+
assert "Cannot partition range" in str(excinfo.value)
285296

286297
def test_pickle(self):
287298
assert pickle.loads(pickle.dumps(self.r)) == self.r
288299
assert pickle.loads(pickle.dumps(self.r2)) == self.r2
289300

290301

291-
def assert_equal_ranges(r1, r2, cls):
292-
assert isinstance(r1, cls)
293-
assert isinstance(r2, cls)
302+
def assert_equal_ranges(r1, r2):
294303
assert r1.min == r2.min
295304
assert r1.max == r2.max
296305
assert r1.weight == r2.weight

tiledb/ml/readers/_tensor_schema/base_sparse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from .base import Tensor, TensorSchema
7-
from .ranges import InclusiveRange
7+
from .ranges import WeightedRange
88

99

1010
class BaseSparseTensorSchema(TensorSchema[Tensor]):
@@ -18,8 +18,8 @@ def __init__(self, **kwargs: Any):
1818
)
1919

2020
@property
21-
def key_range(self) -> InclusiveRange[Any, int]:
22-
self._key_range: InclusiveRange[Any, int]
21+
def key_range(self) -> WeightedRange[Any, int]:
22+
self._key_range: WeightedRange[Any, int]
2323
try:
2424
return self._key_range
2525
except AttributeError:
@@ -30,7 +30,7 @@ def key_range(self) -> InclusiveRange[Any, int]:
3030
assert isinstance(key_dim_slice, slice)
3131
for result in query[key_dim_slice]:
3232
key_counter.update(result[key_dim])
33-
self._key_range = InclusiveRange.factory(key_counter)
33+
self._key_range = WeightedRange.from_mapping(key_counter)
3434
return self._key_range
3535

3636
@property

tiledb/ml/readers/_tensor_schema/dense.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from .base import TensorSchema
8-
from .ranges import InclusiveRange
8+
from .ranges import InclusiveRange, IntRange
99

1010

1111
class DenseTensorSchema(TensorSchema[np.ndarray]):
@@ -21,7 +21,7 @@ def __init__(self, **kwargs: Any):
2121
)
2222

2323
@property
24-
def key_range(self) -> InclusiveRange[int, int]:
24+
def key_range(self) -> IntRange:
2525
try:
2626
key_dim_slice = self._dim_selectors[0]
2727
except KeyError:
@@ -33,7 +33,7 @@ def key_range(self) -> InclusiveRange[int, int]:
3333
raise NotImplementedError(
3434
"Key dimension slicing is not yet implemented for dense arrays"
3535
)
36-
return InclusiveRange.factory(range(key_dim_min, key_dim_max + 1))
36+
return IntRange(key_dim_min, key_dim_max)
3737

3838
def iter_tensors(
3939
self, key_ranges: Iterable[InclusiveRange[int, int]]

0 commit comments

Comments
 (0)