Skip to content

fix:*-like creation routines take kwargs #2992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions changes/2992.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix a bug preventing ``ones_like``, ``full_like``, ``empty_like``, ``zeros_like`` and ``open_like`` functions from accepting
an explicit specification of array attributes like shape, dtype, chunks etc. The functions ``full_like``,
``empty_like``, and ``open_like`` now also more consistently infer a ``fill_value`` parameter from the provided array.
35 changes: 18 additions & 17 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@
return shape, chunks


def _like_args(a: ArrayLike, kwargs: dict[str, Any]) -> dict[str, Any]:
def _like_args(a: ArrayLike) -> dict[str, object]:
"""Set default values for shape and chunks if they are not present in the array-like object"""

new = kwargs.copy()
new: dict[str, object] = {}

Check warning on line 114 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L114

Added line #L114 was not covered by tests

shape, chunks = _get_shape_chunks(a)
if shape is not None:
Expand Down Expand Up @@ -1077,7 +1077,7 @@
shape: ChunkCoords, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an empty array with the specified shape. The contents will be filled with the
array's fill value or zeros if no fill value is provided.
specified fill value or zeros if no fill value is provided.

Parameters
----------
Expand All @@ -1092,8 +1092,7 @@
retrieve data from an empty Zarr array, any values may be returned,
and these are not guaranteed to be stable from one access to the next.
"""

return await create(shape=shape, fill_value=None, **kwargs)
return await create(shape=shape, **kwargs)

Check warning on line 1095 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1095

Added line #L1095 was not covered by tests


async def empty_like(
Expand All @@ -1120,8 +1119,10 @@
retrieve data from an empty Zarr array, any values may be returned,
and these are not guaranteed to be stable from one access to the next.
"""
like_kwargs = _like_args(a, kwargs)
return await empty(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
if isinstance(a, (AsyncArray | Array)):
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await empty(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1125 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1122-L1125

Added lines #L1122 - L1125 were not covered by tests


# TODO: add type annotations for fill_value and kwargs
Expand Down Expand Up @@ -1166,10 +1167,10 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
if isinstance(a, AsyncArray):
like_kwargs = _like_args(a) | kwargs
if isinstance(a, (AsyncArray | Array)):

Check warning on line 1171 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1170-L1171

Added lines #L1170 - L1171 were not covered by tests
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await full(**like_kwargs)
return await full(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1173 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1173

Added line #L1173 was not covered by tests


async def ones(
Expand Down Expand Up @@ -1210,8 +1211,8 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
return await ones(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
return await ones(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1215 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1214-L1215

Added lines #L1214 - L1215 were not covered by tests


async def open_array(
Expand Down Expand Up @@ -1291,10 +1292,10 @@
AsyncArray
The opened array.
"""
like_kwargs = _like_args(a, kwargs)
like_kwargs = _like_args(a) | kwargs

Check warning on line 1295 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1295

Added line #L1295 was not covered by tests
if isinstance(a, (AsyncArray | Array)):
kwargs.setdefault("fill_value", a.metadata.fill_value)
return await open_array(path=path, **like_kwargs)
like_kwargs.setdefault("fill_value", a.metadata.fill_value)
return await open_array(path=path, **like_kwargs) # type: ignore[arg-type]

Check warning on line 1298 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1297-L1298

Added lines #L1297 - L1298 were not covered by tests


async def zeros(
Expand Down Expand Up @@ -1335,5 +1336,5 @@
Array
The new array.
"""
like_kwargs = _like_args(a, kwargs)
return await zeros(**like_kwargs)
like_kwargs = _like_args(a) | kwargs
return await zeros(**like_kwargs) # type: ignore[arg-type]

Check warning on line 1340 in src/zarr/api/asynchronous.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/api/asynchronous.py#L1339-L1340

Added lines #L1339 - L1340 were not covered by tests
77 changes: 76 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, get_args

if TYPE_CHECKING:
import pathlib
Expand Down Expand Up @@ -68,6 +68,81 @@ def test_create(memory_store: Store) -> None:
z = create(shape=(400, 100), chunks=(16, 16.5), store=store, overwrite=True) # type: ignore [arg-type]


LikeFuncName = Literal["zeros_like", "ones_like", "empty_like", "full_like", "open_like"]


@pytest.mark.parametrize("func_name", get_args(LikeFuncName))
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
async def test_array_like_creation(
zarr_format: ZarrFormat,
func_name: LikeFuncName,
out_shape: Literal["keep"] | tuple[int, ...],
out_chunks: Literal["keep"] | tuple[int, ...],
out_dtype: str,
) -> None:
"""
Test zeros_like, ones_like, empty_like, full_like, ensuring that we can override the
shape, chunks, and dtype of the array-like object provided to these functions with
appropriate keyword arguments
"""
ref_arr = zarr.ones(
store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12), zarr_format=zarr_format
)
kwargs: dict[str, object] = {}
if func_name == "full_like":
expect_fill = 4
kwargs["fill_value"] = expect_fill
func = zarr.api.asynchronous.full_like
elif func_name == "zeros_like":
expect_fill = 0
func = zarr.api.asynchronous.zeros_like
elif func_name == "ones_like":
expect_fill = 1
func = zarr.api.asynchronous.ones_like
elif func_name == "empty_like":
expect_fill = ref_arr.fill_value
func = zarr.api.asynchronous.empty_like
elif func_name == "open_like":
expect_fill = ref_arr.fill_value
kwargs["mode"] = "w"
func = zarr.api.asynchronous.open_like # type: ignore[assignment]
else:
raise AssertionError
if out_shape != "keep":
kwargs["shape"] = out_shape
expect_shape = out_shape
else:
expect_shape = ref_arr.shape
if out_chunks != "keep":
kwargs["chunks"] = out_chunks
expect_chunks = out_chunks
else:
expect_chunks = ref_arr.chunks
if out_dtype != "keep":
kwargs["dtype"] = out_dtype
expect_dtype = out_dtype
else:
expect_dtype = ref_arr.dtype # type: ignore[assignment]

new_arr = await func(ref_arr, path="foo", **kwargs)
assert new_arr.shape == expect_shape
assert new_arr.chunks == expect_chunks
assert new_arr.dtype == expect_dtype
assert np.all(Array(new_arr)[:] == expect_fill)


async def test_invalid_full_like() -> None:
"""
Test that a fill value that is incompatible with the proposed dtype is rejected
"""
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
fill = 4
with pytest.raises(ValueError, match=f"fill value {fill} is not valid for dtype DataType.bool"):
await zarr.api.asynchronous.full_like(ref_arr, path="foo", fill_value=fill, dtype="bool")


# TODO: parametrize over everything this function takes
@pytest.mark.parametrize("store", ["memory"], indirect=True)
def test_create_array(store: Store) -> None:
Expand Down
62 changes: 61 additions & 1 deletion tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args

import numpy as np
import pytest
Expand Down Expand Up @@ -668,6 +668,66 @@ def test_group_create_array(
assert np.array_equal(array[:], data)


LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"]


@pytest.mark.parametrize("method_name", get_args(LikeMethodName))
@pytest.mark.parametrize("out_shape", ["keep", (10, 10)])
@pytest.mark.parametrize("out_chunks", ["keep", (10, 10)])
@pytest.mark.parametrize("out_dtype", ["keep", "int8"])
def test_group_array_like_creation(
zarr_format: ZarrFormat,
method_name: LikeMethodName,
out_shape: Literal["keep"] | tuple[int, ...],
out_chunks: Literal["keep"] | tuple[int, ...],
out_dtype: str,
) -> None:
"""
Test Group.{zeros_like, ones_like, empty_like, full_like}, ensuring that we can override the
shape, chunks, and dtype of the array-like object provided to these functions with
appropriate keyword arguments
"""
ref_arr = zarr.ones(store={}, shape=(11, 12), dtype="uint8", chunks=(11, 12))
group = Group.from_store({}, zarr_format=zarr_format)
kwargs = {}
if method_name == "full_like":
expect_fill = 4
kwargs["fill_value"] = expect_fill
meth = group.full_like
elif method_name == "zeros_like":
expect_fill = 0
meth = group.zeros_like
elif method_name == "ones_like":
expect_fill = 1
meth = group.ones_like
elif method_name == "empty_like":
expect_fill = ref_arr.fill_value
meth = group.empty_like
else:
raise AssertionError
if out_shape != "keep":
kwargs["shape"] = out_shape
expect_shape = out_shape
else:
expect_shape = ref_arr.shape
if out_chunks != "keep":
kwargs["chunks"] = out_chunks
expect_chunks = out_chunks
else:
expect_chunks = ref_arr.chunks
if out_dtype != "keep":
kwargs["dtype"] = out_dtype
expect_dtype = out_dtype
else:
expect_dtype = ref_arr.dtype

new_arr = meth(name="foo", data=ref_arr, **kwargs)
assert new_arr.shape == expect_shape
assert new_arr.chunks == expect_chunks
assert new_arr.dtype == expect_dtype
assert np.all(new_arr[:] == expect_fill)


def test_group_array_creation(
store: Store,
zarr_format: ZarrFormat,
Expand Down