Skip to content

Commit

Permalink
Fix view_attr not being respected by __getitem__ and subarray (#…
Browse files Browse the repository at this point in the history
…2139)

Co-authored-by: nguyenv <[email protected]>
  • Loading branch information
kounelisagis and nguyenv authored Jan 28, 2025
1 parent 92fb0e2 commit 9638128
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
5 changes: 3 additions & 2 deletions tiledb/dense_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def __getitem__(self, selection):
"""
if self.view_attr:
result = self.subarray(selection, attrs=(self.view_attr,))
return result[self.view_attr]
return self.subarray(selection)

result = self.subarray(selection)
for i in range(self.schema.nattr):
Expand Down Expand Up @@ -291,6 +290,8 @@ def subarray(self, selection, attrs=None, cond=None, coords=False, order=None):
attr = self.schema.attr(0)
if attr.isanon:
return out[attr._internal_name]
if self.view_attr is not None:
return out[self.view_attr]
return out

def _read_dense_subarray(
Expand Down
10 changes: 9 additions & 1 deletion tiledb/sparse_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections import OrderedDict

import numpy as np
Expand Down Expand Up @@ -292,6 +293,9 @@ def __getitem__(self, selection):
>>> # A[5.0:579.9]
"""
if self.view_attr is not None:
return self.subarray(selection)

result = self.subarray(selection)
for i in range(self.schema.nattr):
attr = self.schema.attr(i)
Expand Down Expand Up @@ -518,7 +522,11 @@ def subarray(self, selection, coords=True, attrs=None, cond=None, order=None):

attr_names = list()

if attrs is None:
if self.view_attr is not None:
if attrs is not None:
warnings.warn("view_attr is set, ignoring attrs parameter", UserWarning)
attr_names.extend(self.view_attr)
elif attrs is None:
attr_names.extend(
self.schema.attr(i)._internal_name for i in range(self.schema.nattr)
)
Expand Down
47 changes: 45 additions & 2 deletions tiledb/tests/test_libtiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .common import (
DiskTestCase,
assert_captured,
assert_dict_arrays_equal,
assert_subarrays_equal,
assert_unordered_equal,
fx_sparse_cell_order, # noqa: F401
Expand Down Expand Up @@ -923,8 +924,8 @@ def assert_ts(timestamp, result):
assert_ts((timestamps[2], None), A * 3)
assert_ts((timestamps[2], None), A * 3)

def test_open_attr(self):
uri = self.path("test_open_attr")
def test_open_attr_dense(self):
uri = self.path("test_open_attr_dense")
schema = tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(name="dim0", dtype=np.uint32, domain=(1, 4))
Expand All @@ -949,6 +950,48 @@ def test_open_attr(self):
assert_array_equal(A[:], np.array((1, 2, 3, 4)))
assert list(A.multi_index[:].keys()) == ["x"]

with tiledb.open(uri, attr="x") as A:
q = A.query(cond="x <= 3")
expected = np.array([1, 2, 3, schema.attr("x").fill[0]])
assert_array_equal(q[:], expected)

def test_open_attr_sparse(self):
uri = self.path("test_open_attr_sparse")
schema = tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(name="dim0", dtype=np.uint32, domain=(1, 4))
),
attrs=(
tiledb.Attr(name="x", dtype=np.int32),
tiledb.Attr(name="y", dtype=np.int32),
),
sparse=True,
)
tiledb.Array.create(uri, schema)

with tiledb.open(uri, mode="w") as A:
A[[1, 2, 3, 4]] = {"x": np.array((1, 2, 3, 4)), "y": np.array((5, 6, 7, 8))}

with self.assertRaises(KeyError):
tiledb.open(uri, attr="z")

with self.assertRaises(KeyError):
tiledb.open(uri, attr="dim0")

with tiledb.open(uri, attr="x") as A:
expected = OrderedDict(
[("dim0", np.array([1, 2, 3, 4])), ("x", np.array([1, 2, 3, 4]))]
)
assert_dict_arrays_equal(A[:], expected)
assert list(A.multi_index[:].keys()) == ["dim0", "x"]

with tiledb.open(uri, attr="x") as A:
q = A.query(cond="x <= 3")
expected = OrderedDict(
[("dim0", np.array([1, 2, 3])), ("x", np.array([1, 2, 3]))]
)
assert_dict_arrays_equal(q[:], expected)

def test_ncell_attributes(self):
dom = tiledb.Domain(tiledb.Dim(domain=(0, 9), tile=10, dtype=int))
attr = tiledb.Attr(dtype=[("", np.int32), ("", np.int32), ("", np.int32)])
Expand Down

0 comments on commit 9638128

Please sign in to comment.