diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e83f5556369..2dcfcdb3860 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -179,14 +179,31 @@ def encode_zarr_attr_value(value): return encoded +def _is_coordinate_variable(zarr_array, name): + if _zarr_v3(): + if zarr_array.metadata.zarr_format == 2: + is_coordinate = name in zarr_array.metadata.attributes.get( + "_ARRAY_DIMENSIONS", [] + ) + else: + is_coordinate = name in (zarr_array.metadata.dimension_names or []) + else: + is_coordinate = name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", []) + return is_coordinate + + class ZarrArrayWrapper(BackendArray): - __slots__ = ("_array", "dtype", "shape") + __slots__ = ("_array", "coords_buffer_prototype", "dtype", "is_coordinate", "shape") - def __init__(self, zarr_array): + def __init__( + self, zarr_array, is_coordinate: bool, coords_buffer_prototype: Any | None + ): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. self._array = zarr_array self.shape = self._array.shape + self.is_coordinate = is_coordinate + self.coords_buffer_prototype = coords_buffer_prototype # preserve vlen string object dtype (GH 7328) if ( @@ -210,7 +227,14 @@ def _vindex(self, key): return self._array.vindex[key] def _getitem(self, key): - return self._array[key] + kwargs = {} + if _zarr_v3(): + if self.is_coordinate: + prototype = self.coords_buffer_prototype + else: + prototype = None + kwargs["prototype"] = prototype + return self._array.get_basic_selection(key, **kwargs) def __getitem__(self, key): array = self._array @@ -605,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore): "_cache_members", "_close_store_on_close", "_consolidate_on_close", + "_coords_buffer_prototype", "_group", "_members", "_mode", @@ -636,6 +661,7 @@ def open_store( use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): ( zarr_group, @@ -668,6 +694,7 @@ def open_store( close_store_on_close, use_zarr_fill_value_as_mask, cache_members=cache_members, + coords_buffer_prototype=coords_buffer_prototype, ) for group in group_paths } @@ -691,6 +718,7 @@ def open_group( use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): ( zarr_group, @@ -722,6 +750,7 @@ def open_group( close_store_on_close, use_zarr_fill_value_as_mask, cache_members, + coords_buffer_prototype, ) def __init__( @@ -736,6 +765,7 @@ def __init__( close_store_on_close: bool = False, use_zarr_fill_value_as_mask=None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only @@ -751,6 +781,14 @@ def __init__( self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask self._cache_members: bool = cache_members self._members: dict[str, ZarrArray | ZarrGroup] = {} + if _zarr_v3() and coords_buffer_prototype is None: + # Once zarr-v3 is required we can just have this as the default + # https://github.com/zarr-developers/zarr-python/issues/2871 + # Use the public API once available + from zarr.core.buffer.cpu import buffer_prototype + + coords_buffer_prototype = buffer_prototype + self._coords_buffer_prototype = coords_buffer_prototype if self._cache_members: # initialize the cache @@ -809,7 +847,15 @@ def ds(self): def open_store_variable(self, name): zarr_array = self.members[name] - data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) + is_coordinate = _is_coordinate_variable(zarr_array, name) + + data = indexing.LazilyIndexedArray( + ZarrArrayWrapper( + zarr_array, + is_coordinate=is_coordinate, + coords_buffer_prototype=self._coords_buffer_prototype, + ) + ) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( zarr_array, DIMENSION_KEY, try_nczarr @@ -1332,6 +1378,7 @@ def open_zarr( use_zarr_fill_value_as_mask=None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, + coords_buffer_prototype: Any | None = None, **kwargs, ): """Load and decode a dataset from a Zarr store. @@ -1442,6 +1489,12 @@ def open_zarr( chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg. Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + coords_buffer_prototype : zarr.buffer.BufferPrototype, optional + The buffer prototype to use for loading coordinate arrays. Zarr offers control over + which device's memory buffers are read into. By default, xarray will always load + *coordinate* buffers into host (CPU) memory, regardless of the global zarr + configuration. To override this behavior, explicitly pass the buffer prototype + to use for coordinates here. Returns ------- @@ -1485,6 +1538,7 @@ def open_zarr( "storage_options": storage_options, "zarr_version": zarr_version, "zarr_format": zarr_format, + "coords_buffer_prototype": coords_buffer_prototype, } ds = open_dataset( @@ -1557,6 +1611,7 @@ def open_dataset( engine=None, use_zarr_fill_value_as_mask=None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) if not store: @@ -1573,6 +1628,7 @@ def open_dataset( use_zarr_fill_value_as_mask=None, zarr_format=zarr_format, cache_members=cache_members, + coords_buffer_prototype=coords_buffer_prototype, ) store_entrypoint = StoreBackendEntrypoint() @@ -1608,6 +1664,7 @@ def open_datatree( storage_options=None, zarr_version=None, zarr_format=None, + coords_buffer_prototype: Any | None = None, ) -> DataTree: filename_or_obj = _normalize_path(filename_or_obj) groups_dict = self.open_groups_as_dict( @@ -1627,6 +1684,7 @@ def open_datatree( storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, + coords_buffer_prototype=coords_buffer_prototype, ) return datatree_from_dict_with_io_cleanup(groups_dict) @@ -1650,6 +1708,7 @@ def open_groups_as_dict( storage_options=None, zarr_version=None, zarr_format=None, + coords_buffer_prototype: Any | None = None, ) -> dict[str, Dataset]: from xarray.core.treenode import NodePath @@ -1672,6 +1731,7 @@ def open_groups_as_dict( storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, + coords_buffer_prototype=coords_buffer_prototype, ) groups_dict = {} diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a6df4d7b0cb..58334e09777 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3766,6 +3766,37 @@ def test_zarr_version_deprecated() -> None: xr.open_zarr(store=store, zarr_version=2, zarr_format=3) +@requires_zarr +def test_coords_buffer_prototype() -> None: + pytest.importorskip("zarr", minversion="3") + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import BufferPrototype + + counter = 0 + + class Buffer(cpu.Buffer): + def __init__(self, *args, **kwargs): + nonlocal counter + counter += 1 + super().__init__(*args, **kwargs) + + class NDBuffer(cpu.NDBuffer): + def __init__(self, *args, **kwargs): + nonlocal counter + counter += 1 + super().__init__(*args, **kwargs) + + prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) + + ds = create_test_data() + store = KVStore() + # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2 + ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload, unused-ignore] + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore[arg-type, unused-ignore] + assert counter > 0 + + @requires_scipy class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): engine: T_NetcdfEngine = "scipy"