diff --git a/changes/2992.fix.rst b/changes/2992.fix.rst new file mode 100644 index 0000000000..7e4211cdde --- /dev/null +++ b/changes/2992.fix.rst @@ -0,0 +1,3 @@ +Fix a bug preventing ``ones_like``, ``full_like``, ``empty_like``, ``zeros_like`` and ``open_like`` functions from accepting +an explicit specification of array attributes like shape, dtype, chunks etc. The functions ``full_like``, +``empty_like``, and ``open_like`` now also more consistently infer a ``fill_value`` parameter from the provided array. diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 285d777258..c02f1e5392 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -108,10 +108,10 @@ def _get_shape_chunks(a: ArrayLike | Any) -> tuple[ChunkCoords | None, ChunkCoor return shape, chunks -def _like_args(a: ArrayLike, kwargs: dict[str, Any]) -> dict[str, Any]: +def _like_args(a: ArrayLike) -> dict[str, object]: """Set default values for shape and chunks if they are not present in the array-like object""" - new = kwargs.copy() + new: dict[str, object] = {} shape, chunks = _get_shape_chunks(a) if shape is not None: @@ -1077,7 +1077,7 @@ async def empty( shape: ChunkCoords, **kwargs: Any ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: """Create an empty array with the specified shape. The contents will be filled with the - array's fill value or zeros if no fill value is provided. + specified fill value or zeros if no fill value is provided. Parameters ---------- @@ -1092,8 +1092,7 @@ async def empty( retrieve data from an empty Zarr array, any values may be returned, and these are not guaranteed to be stable from one access to the next. 
""" - - return await create(shape=shape, fill_value=None, **kwargs) + return await create(shape=shape, **kwargs) async def empty_like( @@ -1120,8 +1119,10 @@ async def empty_like( retrieve data from an empty Zarr array, any values may be returned, and these are not guaranteed to be stable from one access to the next. """ - like_kwargs = _like_args(a, kwargs) - return await empty(**like_kwargs) + like_kwargs = _like_args(a) | kwargs + if isinstance(a, (AsyncArray | Array)): + like_kwargs.setdefault("fill_value", a.metadata.fill_value) + return await empty(**like_kwargs) # type: ignore[arg-type] # TODO: add type annotations for fill_value and kwargs @@ -1166,10 +1167,10 @@ async def full_like( Array The new array. """ - like_kwargs = _like_args(a, kwargs) - if isinstance(a, AsyncArray): + like_kwargs = _like_args(a) | kwargs + if isinstance(a, (AsyncArray | Array)): like_kwargs.setdefault("fill_value", a.metadata.fill_value) - return await full(**like_kwargs) + return await full(**like_kwargs) # type: ignore[arg-type] async def ones( @@ -1210,8 +1211,8 @@ async def ones_like( Array The new array. """ - like_kwargs = _like_args(a, kwargs) - return await ones(**like_kwargs) + like_kwargs = _like_args(a) | kwargs + return await ones(**like_kwargs) # type: ignore[arg-type] async def open_array( @@ -1291,10 +1292,10 @@ async def open_like( AsyncArray The opened array. """ - like_kwargs = _like_args(a, kwargs) + like_kwargs = _like_args(a) | kwargs if isinstance(a, (AsyncArray | Array)): - kwargs.setdefault("fill_value", a.metadata.fill_value) - return await open_array(path=path, **like_kwargs) + like_kwargs.setdefault("fill_value", a.metadata.fill_value) + return await open_array(path=path, **like_kwargs) # type: ignore[arg-type] async def zeros( @@ -1335,5 +1336,5 @@ async def zeros_like( Array The new array. 
""" - like_kwargs = _like_args(a, kwargs) - return await zeros(**like_kwargs) + like_kwargs = _like_args(a) | kwargs + return await zeros(**like_kwargs) # type: ignore[arg-type] diff --git a/tests/test_api.py b/tests/test_api.py index f03fd53f7a..67c77eb971 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, get_args if TYPE_CHECKING: import pathlib @@ -68,6 +68,81 @@ def test_create(memory_store: Store) -> None: z = create(shape=(400, 100), chunks=(16, 16.5), store=store, overwrite=True) # type: ignore [arg-type] +LikeFuncName = Literal["zeros_like", "ones_like", "empty_like", "full_like", "open_like"] + + +@pytest.mark.parametrize("func_name", get_args(LikeFuncName)) +@pytest.mark.parametrize("out_shape", ["keep", (10, 10)]) +@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)]) +@pytest.mark.parametrize("out_dtype", ["keep", "int8"]) +async def test_array_like_creation( + zarr_format: ZarrFormat, + func_name: LikeFuncName, + out_shape: Literal["keep"] | tuple[int, ...], + out_chunks: Literal["keep"] | tuple[int, ...], + out_dtype: str, +) -> None: + """ + Test zeros_like, ones_like, empty_like, full_like and open_like, ensuring we can override the + shape, chunks, and dtype of the array-like object provided to these functions with + appropriate keyword arguments + """ + ref_arr = zarr.ones( + store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12), zarr_format=zarr_format + ) + kwargs: dict[str, object] = {} + if func_name == "full_like": + expect_fill = 4 + kwargs["fill_value"] = expect_fill + func = zarr.api.asynchronous.full_like + elif func_name == "zeros_like": + expect_fill = 0 + func = zarr.api.asynchronous.zeros_like + elif func_name == "ones_like": + expect_fill = 1 + func = zarr.api.asynchronous.ones_like + elif func_name == "empty_like": + expect_fill = ref_arr.fill_value + func = zarr.api.asynchronous.empty_like + elif func_name ==
"open_like": + expect_fill = ref_arr.fill_value + kwargs["mode"] = "w" + func = zarr.api.asynchronous.open_like # type: ignore[assignment] + else: + raise AssertionError + if out_shape != "keep": + kwargs["shape"] = out_shape + expect_shape = out_shape + else: + expect_shape = ref_arr.shape + if out_chunks != "keep": + kwargs["chunks"] = out_chunks + expect_chunks = out_chunks + else: + expect_chunks = ref_arr.chunks + if out_dtype != "keep": + kwargs["dtype"] = out_dtype + expect_dtype = out_dtype + else: + expect_dtype = ref_arr.dtype # type: ignore[assignment] + + new_arr = await func(ref_arr, path="foo", **kwargs) + assert new_arr.shape == expect_shape + assert new_arr.chunks == expect_chunks + assert new_arr.dtype == expect_dtype + assert np.all(Array(new_arr)[:] == expect_fill) + + +async def test_invalid_full_like() -> None: + """ + Test that a fill value that is incompatible with the proposed dtype is rejected + """ + ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12)) + fill = 4 + with pytest.raises(ValueError, match=f"fill value {fill} is not valid for dtype DataType.bool"): + await zarr.api.asynchronous.full_like(ref_arr, path="foo", fill_value=fill, dtype="bool") + + # TODO: parametrize over everything this function takes @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_create_array(store: Store) -> None: diff --git a/tests/test_group.py b/tests/test_group.py index 1e4f31b5d6..de2fc70e99 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -7,7 +7,7 @@ import re import time import warnings -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, get_args import numpy as np import pytest @@ -668,6 +668,66 @@ def test_group_create_array( assert np.array_equal(array[:], data) +LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"] + + +@pytest.mark.parametrize("method_name", get_args(LikeMethodName)) +@pytest.mark.parametrize("out_shape", 
["keep", (10, 10)]) +@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)]) +@pytest.mark.parametrize("out_dtype", ["keep", "int8"]) +def test_group_array_like_creation( + zarr_format: ZarrFormat, + method_name: LikeMethodName, + out_shape: Literal["keep"] | tuple[int, ...], + out_chunks: Literal["keep"] | tuple[int, ...], + out_dtype: str, +) -> None: + """ + Test Group.{zeros_like, ones_like, empty_like, full_like}, ensuring that we can override the + shape, chunks, and dtype of the array-like object provided to these functions with + appropriate keyword arguments + """ + ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12)) + group = Group.from_store({}, zarr_format=zarr_format) + kwargs = {} + if method_name == "full_like": + expect_fill = 4 + kwargs["fill_value"] = expect_fill + meth = group.full_like + elif method_name == "zeros_like": + expect_fill = 0 + meth = group.zeros_like + elif method_name == "ones_like": + expect_fill = 1 + meth = group.ones_like + elif method_name == "empty_like": + expect_fill = ref_arr.fill_value + meth = group.empty_like + else: + raise AssertionError + if out_shape != "keep": + kwargs["shape"] = out_shape + expect_shape = out_shape + else: + expect_shape = ref_arr.shape + if out_chunks != "keep": + kwargs["chunks"] = out_chunks + expect_chunks = out_chunks + else: + expect_chunks = ref_arr.chunks + if out_dtype != "keep": + kwargs["dtype"] = out_dtype + expect_dtype = out_dtype + else: + expect_dtype = ref_arr.dtype + + new_arr = meth(name="foo", data=ref_arr, **kwargs) + assert new_arr.shape == expect_shape + assert new_arr.chunks == expect_chunks + assert new_arr.dtype == expect_dtype + assert np.all(new_arr[:] == expect_fill) + + def test_group_array_creation( store: Store, zarr_format: ZarrFormat,