Skip to content

Commit b9baa54

Browse files
authored
Change DenseTensorSchema.key_range to return a ConstrainedPartitionsIntRange (#193)
1 parent 1e7a755 commit b9baa54

File tree

4 files changed

+46
-30
lines changed

4 files changed

+46
-30
lines changed

tests/readers/test_tensor_schema.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ def dense_uri(tmp_path_factory):
1515
sparse=False,
1616
domain=tiledb.Domain(
1717
tiledb.Dim(name="d0", domain=(0, 9999), dtype=np.int32, tile=123),
18-
tiledb.Dim(name="d1", domain=(-2, 2), dtype=np.int32, tile=2),
19-
tiledb.Dim(name="d2", domain=(1, 2), dtype=np.int32, tile=1),
18+
tiledb.Dim(name="d1", domain=(-20, 19), dtype=np.int32, tile=4),
2019
),
2120
attrs=[
2221
tiledb.Attr(name="af8", dtype=np.float64),
@@ -90,21 +89,31 @@ def parametrize_fields(*fields, num=3):
9089
@pytest.mark.parametrize(
9190
"key_dim,memory_budget,dim_selectors",
9291
[
93-
("d0", 16_000, {}),
94-
("d0", 32_000, {}),
95-
("d0", 64_000, {}),
96-
("d1", 500_000, {}),
97-
("d1", 600_000, {}),
92+
("d0", 160_000, {}),
93+
("d0", 160_000, {"d0": slice(1000, 9000)}),
94+
("d0", 160_000, {"d0": slice(None, 9000)}),
95+
("d0", 160_000, {"d0": slice(1000, None)}),
96+
("d0", 160_000, {"d1": slice(-10, 10)}),
97+
("d0", 160_000, {"d1": slice(-10, None)}),
98+
("d0", 160_000, {"d1": slice(None, 10)}),
99+
("d0", 160_000, {"d1": list(range(-10, 10, 3))}),
100+
("d0", 160_000, {"d0": slice(1000, 9000), "d1": slice(-10, 10)}),
101+
("d0", 160_000, {"d0": slice(None, 9000), "d1": slice(-10, None)}),
102+
("d0", 160_000, {"d0": slice(1000, None), "d1": slice(None, 10)}),
98103
("d1", 700_000, {}),
99-
("d0", 16_000, {"d1": slice(0, 2)}),
100-
("d0", 32_000, {"d1": slice(None, 0)}),
101-
("d0", 64_000, {"d1": slice(-1, None), "d2": [1]}),
102-
("d1", 500_000, {"d0": [1, 2, 3]}),
103-
("d1", 600_000, {"d0": [1, 100, 143, 976], "d2": slice(2, 2)}),
104-
("d1", 700_000, {"d0": [1, 100, 143, 1093, 1094]}),
104+
("d1", 700_000, {"d1": slice(-10, 10)}),
105+
("d1", 700_000, {"d1": slice(None, 10)}),
106+
("d1", 700_000, {"d1": slice(-10, None)}),
107+
("d1", 700_000, {"d0": slice(1000, 9000)}),
108+
("d1", 700_000, {"d0": slice(None, 9000)}),
109+
("d1", 700_000, {"d0": slice(1000, None)}),
110+
("d1", 700_000, {"d0": list(range(100, 5000, 3))}),
111+
("d1", 700_000, {"d1": slice(-10, 10), "d0": slice(1000, 9000)}),
112+
("d1", 700_000, {"d1": slice(None, 10), "d0": slice(None, 9000)}),
113+
("d1", 700_000, {"d1": slice(-10, None), "d0": slice(1000, None)}),
105114
],
106115
)
107-
@parametrize_fields("d0", "d1", "d2", "af8", "af4", "au1")
116+
@parametrize_fields("d0", "d1", "af8", "af4", "au1")
108117
def test_max_partition_weight_dense(
109118
dense_uri, fields, key_dim, memory_budget, dim_selectors
110119
):

tests/readers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def dim(self):
153153
return tiledb.Dim(
154154
name=self.name,
155155
domain=(self.min_value, self(self.size - 1)),
156-
tile=np.random.randint(1, self.size + 1),
156+
tile=np.random.randint(1, self.size),
157157
dtype=self.dtype,
158158
)
159159

tiledb/ml/readers/_tensor_schema/dense.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from .base import TensorSchema
8-
from .ranges import InclusiveRange, IntRange
8+
from .ranges import ConstrainedPartitionsIntRange, InclusiveRange
99

1010

1111
class DenseTensorSchema(TensorSchema[np.ndarray]):
@@ -21,19 +21,26 @@ def __init__(self, **kwargs: Any):
2121
)
2222

2323
@property
24-
def key_range(self) -> IntRange:
24+
def key_range(self) -> ConstrainedPartitionsIntRange:
25+
self._key_range: ConstrainedPartitionsIntRange
2526
try:
26-
key_dim_slice = self._dim_selectors[0]
27-
except KeyError:
27+
return self._key_range
28+
except AttributeError:
2829
key_dim_min, key_dim_max = self._ned[0]
29-
else:
30-
assert isinstance(key_dim_slice, slice)
31-
key_dim_min = key_dim_slice.start
32-
key_dim_max = key_dim_slice.stop
33-
raise NotImplementedError(
34-
"Key dimension slicing is not yet implemented for dense arrays"
30+
key_dim_slice = self._dim_selectors.get(0)
31+
if key_dim_slice is not None:
32+
assert isinstance(key_dim_slice, slice)
33+
min_key = key_dim_slice.start
34+
max_key = key_dim_slice.stop
35+
else:
36+
min_key = key_dim_min
37+
max_key = key_dim_max
38+
key_dim_tile = self._array.dim(self.key_dim).tile
39+
start_offsets = range(key_dim_min, key_dim_max + 1, key_dim_tile)
40+
self._key_range = ConstrainedPartitionsIntRange(
41+
min_key, max_key, start_offsets
3542
)
36-
return IntRange(key_dim_min, key_dim_max)
43+
return self._key_range
3744

3845
def iter_tensors(
3946
self, key_ranges: Iterable[InclusiveRange[int, int]]

tiledb/ml/readers/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class ArrayParams:
3030
- array: TileDB array to be accessed
3131
- key_dim: Name (or index) of the array key dimension. Defaults to the first dimension.
3232
- fields: Fields (dimensions and attributes) to be retrieved from array. Defaults to
33-
all attributes of the array.
33+
all attributes of the array.
3434
- dim_selectors: Mapping from dimension name to a slice or sequence of indices of this
35-
dimension to select. Currently implemented only for non-key dimensions of dense arrays
36-
- tensor_kind: kind of tensor desired. If not specified, the default tensor kind is
37-
determined based on the array schema.
35+
dimension to select.
36+
- tensor_kind: kind of tensor desired. If not specified, it is determined based on the
37+
array schema.
3838
"""
3939

4040
array: tiledb.Array

0 commit comments

Comments
 (0)