Skip to content

Commit 1ffc9b9

Browse files
committed
Fix fill value for complex attributes
1 parent 521dcbf commit 1ffc9b9

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

tiledb/attribute.py

+13
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ def __init__(
8383
if fill is not None:
8484
if self._tiledb_dtype == lt.DataType.STRING_UTF8:
8585
self._fill = np.array([fill.encode("utf-8")], dtype="S")
86+
elif self.dtype == np.dtype("complex64") or self.dtype == np.dtype(
87+
"complex128"
88+
):
89+
if hasattr(fill, "dtype") and fill.dtype in {
90+
np.dtype("f4, f4"),
91+
np.dtype("f8, f8"),
92+
}:
93+
_fill = fill["f0"] + fill["f1"] * 1j
94+
elif hasattr(fill, "__len__") and len(fill) == 2:
95+
_fill = fill[0] + fill[1] * 1j
96+
else:
97+
_fill = fill
98+
self._fill = np.array(_fill, dtype=self.dtype)
8699
else:
87100
self._fill = np.array([fill], dtype=self.dtype)
88101

tiledb/cc/attribute.cc

+7-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ py::array get_fill_value(Attribute &attr) {
3737
value_num = 1;
3838
}
3939

40+
// complex type - both cell values fit in a single complex element
41+
if (value_type == py::dtype("complex64") ||
42+
value_type == py::dtype("complex128")) {
43+
value_num = 1;
44+
}
45+
4046
return py::array(value_type, value_num, value);
4147
}
4248

@@ -91,4 +97,4 @@ void init_attribute(py::module &m) {
9197
.def("_dump", [](Attribute &attr) { attr.dump(); });
9298
}
9399

94-
} // namespace libtiledbcpp
100+
} // namespace libtiledbcpp

tiledb/tests/test_attribute.py

+37
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,43 @@ def test_ncell_attribute(self):
8888
with self.assertRaises(TypeError):
8989
tiledb.Attr("foo", dtype=np.dtype([("", np.float32), ("", np.int32)]))
9090

91+
def test_complex64_attribute(self):
92+
attr = tiledb.Attr("foo", fill=(0 + 1j), dtype=np.dtype("complex64"))
93+
assert attr == attr
94+
assert attr.fill == attr.fill
95+
assert attr.dtype == np.complex64
96+
assert attr.ncells == 2
97+
98+
def test_complex128_attribute(self):
99+
dtype = np.dtype([("", np.double), ("", np.double)])
100+
attr = tiledb.Attr("foo", fill=(2.0, 2.0), dtype=dtype)
101+
102+
assert attr == attr
103+
assert attr.fill == attr.fill
104+
assert attr.dtype == np.complex128
105+
assert attr.ncells == 2
106+
107+
@pytest.mark.parametrize(
108+
"fill", [(1.0, 1.0), np.array((1.0, 1.0), dtype=np.dtype("f4, f4"))]
109+
)
110+
def test_two_cell_float_attribute(self, fill):
111+
attr = tiledb.Attr("foo", fill=fill, dtype=np.dtype("f4, f4"))
112+
113+
assert attr == attr
114+
assert attr.dtype == np.complex64
115+
assert attr.fill == attr.fill
116+
assert attr.ncells == 2
117+
118+
@pytest.mark.parametrize(
119+
"fill", [(1.0, 1.0), np.array((1.0, 1.0), dtype=np.dtype("f8, f8"))]
120+
)
121+
def test_two_cell_double_attribute(self, fill):
122+
attr = tiledb.Attr("foo", fill=fill, dtype=np.dtype("f8, f8"))
123+
assert attr == attr
124+
assert attr.dtype == np.complex128
125+
assert attr.fill == attr.fill
126+
assert attr.ncells == 2
127+
91128
def test_ncell_bytes_attribute(self):
92129
dtype = np.dtype((np.bytes_, 10))
93130
attr = tiledb.Attr("foo", dtype=dtype)

0 commit comments

Comments
 (0)