Skip to content

Commit 108ec58

Browse files
TomNicholas and d-v-b authored
Add async oindex and vindex methods to AsyncArray (#3083)
Co-authored-by: Davis Bennett <[email protected]>
1 parent a0c56fb commit 108ec58

File tree

8 files changed

+270
-17
lines changed

8 files changed

+270
-17
lines changed

changes/3083.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added support for async vectorized and orthogonal indexing.

src/zarr/core/array.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
ZarrFormat,
6262
_default_zarr_format,
6363
_warn_order_kwarg,
64+
ceildiv,
6465
concurrent_map,
6566
parse_shapelike,
6667
product,
@@ -76,6 +77,8 @@
7677
)
7778
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
7879
from zarr.core.indexing import (
80+
AsyncOIndex,
81+
AsyncVIndex,
7982
BasicIndexer,
8083
BasicSelection,
8184
BlockIndex,
@@ -92,7 +95,6 @@
9295
Selection,
9396
VIndex,
9497
_iter_grid,
95-
ceildiv,
9698
check_fields,
9799
check_no_multi_fields,
98100
is_pure_fancy_indexing,
@@ -1425,6 +1427,56 @@ async def getitem(
14251427
)
14261428
return await self._get_selection(indexer, prototype=prototype)
14271429

1430+
async def get_orthogonal_selection(
    self,
    selection: OrthogonalSelection,
    *,
    out: NDBuffer | None = None,
    fields: Fields | None = None,
    prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
    """
    Asynchronously read data using an orthogonal (outer) selection.

    An ``OrthogonalIndexer`` is built from ``selection``, the array shape and
    the chunk grid of the array metadata, and the read itself is delegated
    to ``_get_selection``.

    Parameters
    ----------
    selection : OrthogonalSelection
        The selection handed to ``OrthogonalIndexer``.
    out : NDBuffer, optional
        Forwarded unchanged to ``_get_selection``.
    fields : Fields, optional
        Forwarded unchanged to ``_get_selection``.
    prototype : BufferPrototype, optional
        Buffer prototype used for the read. When ``None``, the result of
        ``default_buffer_prototype()`` is used.

    Returns
    -------
    NDArrayLikeOrScalar
        Whatever ``_get_selection`` returns for the constructed indexer.
    """
    buffer_prototype = default_buffer_prototype() if prototype is None else prototype
    orthogonal_indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
    result = await self._get_selection(
        indexer=orthogonal_indexer,
        out=out,
        fields=fields,
        prototype=buffer_prototype,
    )
    return result
1444+
1445+
async def get_mask_selection(
    self,
    mask: MaskSelection,
    *,
    out: NDBuffer | None = None,
    fields: Fields | None = None,
    prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
    """
    Asynchronously read data using a mask selection.

    A ``MaskIndexer`` is built from ``mask``, the array shape and the chunk
    grid of the array metadata, and the read itself is delegated to
    ``_get_selection``.

    Parameters
    ----------
    mask : MaskSelection
        The mask handed to ``MaskIndexer``.
    out : NDBuffer, optional
        Forwarded unchanged to ``_get_selection``.
    fields : Fields, optional
        Forwarded unchanged to ``_get_selection``.
    prototype : BufferPrototype, optional
        Buffer prototype used for the read. When ``None``, the result of
        ``default_buffer_prototype()`` is used.

    Returns
    -------
    NDArrayLikeOrScalar
        Whatever ``_get_selection`` returns for the constructed indexer.
    """
    buffer_prototype = default_buffer_prototype() if prototype is None else prototype
    mask_indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
    result = await self._get_selection(
        indexer=mask_indexer,
        out=out,
        fields=fields,
        prototype=buffer_prototype,
    )
    return result
1459+
1460+
async def get_coordinate_selection(
    self,
    selection: CoordinateSelection,
    *,
    out: NDBuffer | None = None,
    fields: Fields | None = None,
    prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
    """
    Asynchronously read data using a coordinate (point) selection.

    A ``CoordinateIndexer`` is built from ``selection``, the array shape and
    the chunk grid of the array metadata, and the read is delegated to
    ``_get_selection``. If the value that comes back has a ``shape``
    attribute it is converted with ``np.array`` and reshaped to the
    indexer's ``sel_shape``; otherwise it is returned untouched.

    Parameters
    ----------
    selection : CoordinateSelection
        The selection handed to ``CoordinateIndexer``.
    out : NDBuffer, optional
        Forwarded unchanged to ``_get_selection``.
    fields : Fields, optional
        Forwarded unchanged to ``_get_selection``.
    prototype : BufferPrototype, optional
        Buffer prototype used for the read. When ``None``, the result of
        ``default_buffer_prototype()`` is used.

    Returns
    -------
    NDArrayLikeOrScalar
        The selected data, reshaped to ``sel_shape`` when it is array-like.
    """
    buffer_prototype = default_buffer_prototype() if prototype is None else prototype
    coordinate_indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
    fetched = await self._get_selection(
        indexer=coordinate_indexer,
        out=out,
        fields=fields,
        prototype=buffer_prototype,
    )

    # Values without a shape (scalars) need no reshaping.
    if not hasattr(fetched, "shape"):
        return fetched

    # restore the shape recorded by the indexer for this selection
    return np.array(fetched).reshape(coordinate_indexer.sel_shape)
1479+
14281480
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
14291481
"""
14301482
Asynchronously save the array metadata.
@@ -1556,6 +1608,19 @@ async def setitem(
15561608
)
15571609
return await self._set_selection(indexer, value, prototype=prototype)
15581610

1611+
@property
def oindex(self) -> AsyncOIndex[T_ArrayMetadata]:
    """
    Accessor for orthogonal (outer) indexing.

    Returns a new ``AsyncOIndex`` wrapping this array on every access.
    See :func:`get_orthogonal_selection` and :func:`set_orthogonal_selection`
    for documentation and examples.
    """
    orthogonal_accessor = AsyncOIndex(self)
    return orthogonal_accessor
1616+
1617+
@property
def vindex(self) -> AsyncVIndex[T_ArrayMetadata]:
    """
    Accessor for vectorized (inner) indexing.

    Returns a new ``AsyncVIndex`` wrapping this array on every access.
    See :func:`get_coordinate_selection`, :func:`set_coordinate_selection`,
    :func:`get_mask_selection` and :func:`set_mask_selection` for
    documentation and examples.
    """
    vectorized_accessor = AsyncVIndex(self)
    return vectorized_accessor
1623+
15591624
async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
15601625
"""
15611626
Asynchronously resize the array to a new shape.

src/zarr/core/chunk_grids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
ChunkCoords,
1919
ChunkCoordsLike,
2020
ShapeLike,
21+
ceildiv,
2122
parse_named_configuration,
2223
parse_shapelike,
2324
)
24-
from zarr.core.indexing import ceildiv
2525

2626
if TYPE_CHECKING:
2727
from collections.abc import Iterator

src/zarr/core/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import functools
5+
import math
56
import operator
67
import warnings
78
from collections.abc import Iterable, Mapping, Sequence
@@ -69,6 +70,12 @@ def product(tup: ChunkCoords) -> int:
6970
return functools.reduce(operator.mul, tup, 1)
7071

7172

73+
def ceildiv(a: float, b: float) -> int:
    """
    Return the ceiling of ``a / b`` as an ``int``.

    Parameters
    ----------
    a : float
        The dividend. When it compares equal to zero the result is ``0``
        and ``b`` is never inspected (so ``ceildiv(0, 0)`` is ``0``).
    b : float
        The divisor.

    Returns
    -------
    int
        The smallest integer greater than or equal to ``a / b``.

    Raises
    ------
    ZeroDivisionError
        If ``b`` is zero and ``a`` is not.
    """
    if a == 0:
        return 0
    if isinstance(a, int) and isinstance(b, int):
        # Exact integer ceiling division. Going through ``a / b`` would round
        # to the nearest float first, which is wrong for operands above 2**53
        # and raises OverflowError for very large ints.
        return -(a // -b)
    return math.ceil(a / b)
77+
78+
7279
T = TypeVar("T", bound=tuple[Any, ...])
7380
V = TypeVar("V")
7481

src/zarr/core/indexing.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import (
1313
TYPE_CHECKING,
1414
Any,
15+
Generic,
1516
Literal,
1617
NamedTuple,
1718
Protocol,
@@ -25,14 +26,16 @@
2526
import numpy as np
2627
import numpy.typing as npt
2728

28-
from zarr.core.common import product
29+
from zarr.core.common import ceildiv, product
30+
from zarr.core.metadata import T_ArrayMetadata
2931

3032
if TYPE_CHECKING:
31-
from zarr.core.array import Array
33+
from zarr.core.array import Array, AsyncArray
3234
from zarr.core.buffer import NDArrayLikeOrScalar
3335
from zarr.core.chunk_grids import ChunkGrid
3436
from zarr.core.common import ChunkCoords
3537

38+
3639
IntSequence = list[int] | npt.NDArray[np.intp]
3740
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
3841
BasicSelector = int | slice | EllipsisType
@@ -93,12 +96,6 @@ class Indexer(Protocol):
9396
def __iter__(self) -> Iterator[ChunkProjection]: ...
9497

9598

96-
def ceildiv(a: float, b: float) -> int:
97-
if a == 0:
98-
return 0
99-
return math.ceil(a / b)
100-
101-
10299
_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"]
103100

104101

@@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
960957
)
961958

962959

960+
@dataclass(frozen=True)
class AsyncOIndex(Generic[T_ArrayMetadata]):
    """
    Awaitable orthogonal (outer) indexing helper bound to an ``AsyncArray``.

    Reads are performed through ``getitem``, which delegates to the wrapped
    array's ``get_orthogonal_selection``.
    """

    # The async array that orthogonal reads are delegated to.
    array: AsyncArray[T_ArrayMetadata]

    async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
        """
        Read from the wrapped array using an orthogonal selection.

        Parameters
        ----------
        selection : OrthogonalSelection | Array
            The selection to apply. A Zarr ``Array`` is first converted with
            ``_zarr_array_to_int_or_bool_array``. Field names are split off
            with ``pop_fields``; the remainder is passed through
            ``ensure_tuple`` and ``replace_lists``.

        Returns
        -------
        NDArrayLikeOrScalar
            The result of ``AsyncArray.get_orthogonal_selection``.
        """
        # Imported here rather than at module level (the module only imports
        # Array under TYPE_CHECKING).
        from zarr.core.array import Array

        # A Zarr array used as an indexer is materialized up front.
        if isinstance(selection, Array):
            selection = _zarr_array_to_int_or_bool_array(selection)

        fields, remainder = pop_fields(selection)
        normalized = replace_lists(ensure_tuple(remainder))
        orthogonal_selection = cast(OrthogonalSelection, normalized)
        return await self.array.get_orthogonal_selection(orthogonal_selection, fields=fields)
977+
978+
963979
@dataclass(frozen=True)
964980
class BlockIndexer(Indexer):
965981
dim_indexers: list[SliceDimIndexer]
@@ -1268,6 +1284,32 @@ def __setitem__(
12681284
raise VindexInvalidSelectionError(new_selection)
12691285

12701286

1287+
@dataclass(frozen=True)
class AsyncVIndex(Generic[T_ArrayMetadata]):
    """
    Awaitable vectorized (inner) indexing helper bound to an ``AsyncArray``.

    Reads are performed through ``getitem``, which dispatches to the wrapped
    array's ``get_coordinate_selection`` or ``get_mask_selection``.
    """

    # The async array that vectorized reads are delegated to.
    array: AsyncArray[T_ArrayMetadata]

    # TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
    async def getitem(
        self, selection: CoordinateSelection | MaskSelection | Array
    ) -> NDArrayLikeOrScalar:
        """
        Read from the wrapped array using a coordinate or mask selection.

        Parameters
        ----------
        selection : CoordinateSelection | MaskSelection | Array
            The selection to apply. A Zarr ``Array`` is first converted with
            ``_zarr_array_to_int_or_bool_array``. Field names are split off
            with ``pop_fields``; the remainder is passed through
            ``ensure_tuple`` and ``replace_lists``.

        Returns
        -------
        NDArrayLikeOrScalar
            The result of ``get_coordinate_selection`` when the normalized
            selection satisfies ``is_coordinate_selection``, otherwise the
            result of ``get_mask_selection`` when it satisfies
            ``is_mask_selection``.

        Raises
        ------
        VindexInvalidSelectionError
            If the normalized selection is neither a coordinate selection
            nor a mask selection for the array's shape.
        """
        # TODO deduplicate these internals with the sync version of getitem
        # TODO requires solving this circular sync issue: https://github.com/zarr-developers/zarr-python/pull/3083#discussion_r2230737448
        from zarr.core.array import Array

        # A Zarr array used as an indexer is materialized up front.
        if isinstance(selection, Array):
            selection = _zarr_array_to_int_or_bool_array(selection)

        fields, remainder = pop_fields(selection)
        normalized = replace_lists(ensure_tuple(remainder))

        # The coordinate check runs first; the mask check is only the fallback.
        if is_coordinate_selection(normalized, self.array.shape):
            return await self.array.get_coordinate_selection(normalized, fields=fields)
        if is_mask_selection(normalized, self.array.shape):
            return await self.array.get_mask_selection(normalized, fields=fields)
        raise VindexInvalidSelectionError(normalized)
1311+
1312+
12711313
def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
12721314
# early out
12731315
if fields is None:

tests/test_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
4242
from zarr.core.chunk_grids import _auto_partition
4343
from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams
44-
from zarr.core.common import JSON, ZarrFormat
44+
from zarr.core.common import JSON, ZarrFormat, ceildiv
4545
from zarr.core.dtype import (
4646
DateTime64,
4747
Float32,
@@ -59,7 +59,7 @@
5959
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
6060
from zarr.core.dtype.npy.string import UTF8Base
6161
from zarr.core.group import AsyncGroup
62-
from zarr.core.indexing import BasicIndexer, ceildiv
62+
from zarr.core.indexing import BasicIndexer
6363
from zarr.core.metadata.v2 import ArrayV2Metadata
6464
from zarr.core.metadata.v3 import ArrayV3Metadata
6565
from zarr.core.sync import sync

tests/test_indexing.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,3 +1994,110 @@ def test_iter_chunk_regions():
19941994
assert_array_equal(a[region], np.ones_like(a[region]))
19951995
a[region] = 0
19961996
assert_array_equal(a[region], np.zeros_like(a[region]))
1997+
1998+
1999+
class TestAsync:
    """Tests for the awaitable ``oindex`` / ``vindex`` accessors of the async array."""

    @staticmethod
    def _filled_2x2(store):
        """Create a 2x2 ``i8`` array with 1x1 chunks holding ``[[1, 2], [3, 4]]``."""
        arr = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
        arr[...] = np.array([[1, 2], [3, 4]])
        return arr

    @pytest.mark.parametrize(
        ("indexer", "expected"),
        [
            # int
            ((0,), np.array([1, 2])),
            ((1,), np.array([3, 4])),
            ((0, 1), np.array(2)),
            # slice
            ((slice(None),), np.array([[1, 2], [3, 4]])),
            ((slice(0, 1),), np.array([[1, 2]])),
            ((slice(1, 2),), np.array([[3, 4]])),
            ((slice(0, 2),), np.array([[1, 2], [3, 4]])),
            ((slice(0, 0),), np.empty(shape=(0, 2), dtype="i8")),
            # ellipsis
            ((...,), np.array([[1, 2], [3, 4]])),
            ((0, ...), np.array([1, 2])),
            ((..., 0), np.array([1, 3])),
            ((0, 1, ...), np.array(2)),
            # combined
            ((0, slice(None)), np.array([1, 2])),
            ((slice(None), 0), np.array([1, 3])),
            ((slice(None), slice(None)), np.array([[1, 2], [3, 4]])),
            # array of ints
            (([0]), np.array([[1, 2]])),
            (([1]), np.array([[3, 4]])),
            (([0], [1]), np.array(2)),
            (([0, 1], [0]), np.array([[1], [3]])),
            (([0, 1], [0, 1]), np.array([[1, 2], [3, 4]])),
            # boolean array
            (np.array([True, True]), np.array([[1, 2], [3, 4]])),
            (np.array([True, False]), np.array([[1, 2]])),
            (np.array([False, True]), np.array([[3, 4]])),
            (np.array([False, False]), np.empty(shape=(0, 2), dtype="i8")),
        ],
    )
    @pytest.mark.asyncio
    async def test_async_oindex(self, store, indexer, expected):
        """Orthogonal reads through ``oindex.getitem`` match the expected values."""
        async_arr = self._filled_2x2(store)._async_array

        actual = await async_arr.oindex.getitem(indexer)
        assert_array_equal(actual, expected)

    @pytest.mark.asyncio
    async def test_async_oindex_with_zarr_array(self, store):
        """A boolean Zarr array can be used directly as an orthogonal indexer."""
        async_arr = self._filled_2x2(store)._async_array

        # create boolean zarr array to index with
        bool_index = zarr.create_array(
            store=store, name="z2", shape=(2,), chunks=(1,), zarr_format=3, dtype="?"
        )
        bool_index[...] = np.array([True, False])

        actual = await async_arr.oindex.getitem(bool_index)
        assert_array_equal(actual, np.array([[1, 2]]))

    @pytest.mark.parametrize(
        ("indexer", "expected"),
        [
            (([0], [0]), np.array(1)),
            (([0, 1], [0, 1]), np.array([1, 4])),
            (np.array([[False, True], [False, True]]), np.array([2, 4])),
        ],
    )
    @pytest.mark.asyncio
    async def test_async_vindex(self, store, indexer, expected):
        """Vectorized reads through ``vindex.getitem`` match the expected values."""
        async_arr = self._filled_2x2(store)._async_array

        actual = await async_arr.vindex.getitem(indexer)
        assert_array_equal(actual, expected)

    @pytest.mark.asyncio
    async def test_async_vindex_with_zarr_array(self, store):
        """A boolean Zarr array can be used directly as a vectorized (mask) indexer."""
        async_arr = self._filled_2x2(store)._async_array

        # create boolean zarr array to index with
        bool_index = zarr.create_array(
            store=store, name="z2", shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="?"
        )
        bool_index[...] = np.array([[False, True], [False, True]])

        actual = await async_arr.vindex.getitem(bool_index)
        assert_array_equal(actual, np.array([2, 4]))

    @pytest.mark.asyncio
    async def test_async_invalid_indexer(self, store):
        """Both accessors raise ``IndexError`` for a string selection."""
        async_arr = self._filled_2x2(store)._async_array

        # vindex is checked first, then oindex, each in its own context
        with pytest.raises(IndexError):
            await async_arr.vindex.getitem("invalid_indexer")

        with pytest.raises(IndexError):
            await async_arr.oindex.getitem("invalid_indexer")

0 commit comments

Comments
 (0)