Skip to content

Commit 8b59a38

Browse files
dcheriand-v-b
andauthored
Fix a bug when setting complete chunks (#2851)
* Fix a bug when setting complete chunks Closes #2849 * much simpler test * add release note * Update strategy priorities: 1. Emphasize arrays of side > 1, 2. Emphasize indexing the last chunk for both setitem & getitem * Use short node names * bug fix * Add scalar tests * [revert] * Add unit test * one more test * Add xfails * switch to skip, XPASS is not allwoed * Fix test * cleaniup --------- Co-authored-by: Davis Bennett <[email protected]>
1 parent 96c9677 commit 8b59a38

File tree

7 files changed

+133
-36
lines changed

7 files changed

+133
-36
lines changed

changes/2851.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a bug when setting values of a smaller last chunk.

src/zarr/core/codec_pipeline.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,6 @@ def _merge_chunk_array(
296296
is_complete_chunk: bool,
297297
drop_axes: tuple[int, ...],
298298
) -> NDBuffer:
299-
if is_complete_chunk and value.shape == chunk_spec.shape:
300-
return value
301-
if existing_chunk_array is None:
302-
chunk_array = chunk_spec.prototype.nd_buffer.create(
303-
shape=chunk_spec.shape,
304-
dtype=chunk_spec.dtype,
305-
order=chunk_spec.order,
306-
fill_value=fill_value_or_default(chunk_spec),
307-
)
308-
else:
309-
chunk_array = existing_chunk_array.copy() # make a writable copy
310299
if chunk_selection == () or is_scalar(value.as_ndarray_like(), chunk_spec.dtype):
311300
chunk_value = value
312301
else:
@@ -320,6 +309,20 @@ def _merge_chunk_array(
320309
for idx in range(chunk_spec.ndim)
321310
)
322311
chunk_value = chunk_value[item]
312+
if is_complete_chunk and chunk_value.shape == chunk_spec.shape:
313+
# TODO: For the last chunk, we could have is_complete_chunk=True
314+
# that is smaller than the chunk_spec.shape but this throws
315+
# an error in the _decode_single
316+
return chunk_value
317+
if existing_chunk_array is None:
318+
chunk_array = chunk_spec.prototype.nd_buffer.create(
319+
shape=chunk_spec.shape,
320+
dtype=chunk_spec.dtype,
321+
order=chunk_spec.order,
322+
fill_value=fill_value_or_default(chunk_spec),
323+
)
324+
else:
325+
chunk_array = existing_chunk_array.copy() # make a writable copy
323326
chunk_array[chunk_selection] = chunk_value
324327
return chunk_array
325328

src/zarr/storage/_fsspec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def from_url(
178178
try:
179179
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
180180

181-
fs = AsyncFileSystemWrapper(fs)
181+
fs = AsyncFileSystemWrapper(fs, asynchronous=True)
182182
except ImportError as e:
183183
raise ImportError(
184184
f"The filesystem for URL '{url}' is synchronous, and the required "

src/zarr/testing/stateful.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def __init__(self, store: Store) -> None:
325325
def init_store(self) -> None:
326326
self.store.clear()
327327

328-
@rule(key=zarr_keys, data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
328+
@rule(key=zarr_keys(), data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
329329
def set(self, key: str, data: DataObject) -> None:
330330
note(f"(set) Setting {key!r} with {data}")
331331
assert not self.store.read_only
@@ -334,7 +334,7 @@ def set(self, key: str, data: DataObject) -> None:
334334
self.model[key] = data_buf
335335

336336
@precondition(lambda self: len(self.model.keys()) > 0)
337-
@rule(key=zarr_keys, data=st.data())
337+
@rule(key=zarr_keys(), data=st.data())
338338
def get(self, key: str, data: DataObject) -> None:
339339
key = data.draw(
340340
st.sampled_from(sorted(self.model.keys()))
@@ -344,7 +344,7 @@ def get(self, key: str, data: DataObject) -> None:
344344
# to bytes here necessary because data_buf set to model in set()
345345
assert self.model[key] == store_value
346346

347-
@rule(key=zarr_keys, data=st.data())
347+
@rule(key=zarr_keys(), data=st.data())
348348
def get_invalid_zarr_keys(self, key: str, data: DataObject) -> None:
349349
note("(get_invalid)")
350350
assume(key not in self.model)
@@ -408,7 +408,7 @@ def is_empty(self) -> None:
408408
# make sure they either both are or both aren't empty (same state)
409409
assert self.store.is_empty("") == (not self.model)
410410

411-
@rule(key=zarr_keys)
411+
@rule(key=zarr_keys())
412412
def exists(self, key: str) -> None:
413413
note("(exists)")
414414

src/zarr/testing/strategies.py

+71-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import math
12
import sys
23
from typing import Any, Literal
34

45
import hypothesis.extra.numpy as npst
56
import hypothesis.strategies as st
67
import numpy as np
7-
from hypothesis import given, settings # noqa: F401
8+
from hypothesis import event, given, settings # noqa: F401
89
from hypothesis.strategies import SearchStrategy
910

1011
import zarr
@@ -28,6 +29,16 @@
2829
)
2930

3031

32+
@st.composite # type: ignore[misc]
33+
def keys(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any:
34+
return draw(st.lists(node_names, min_size=1, max_size=max_num_nodes).map("/".join))
35+
36+
37+
@st.composite # type: ignore[misc]
38+
def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any:
39+
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))
40+
41+
3142
def v3_dtypes() -> st.SearchStrategy[np.dtype]:
3243
return (
3344
npst.boolean_dtypes()
@@ -87,17 +98,19 @@ def clear_store(x: Store) -> Store:
8798
node_names = st.text(zarr_key_chars, min_size=1).filter(
8899
lambda t: t not in (".", "..") and not t.startswith("__")
89100
)
101+
short_node_names = st.text(zarr_key_chars, max_size=3, min_size=1).filter(
102+
lambda t: t not in (".", "..") and not t.startswith("__")
103+
)
90104
array_names = node_names
91105
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
92-
keys = st.lists(node_names, min_size=1).map("/".join)
93-
paths = st.just("/") | keys
94106
# st.builds will only call a new store constructor for different keyword arguments
95107
# i.e. stores.examples() will always return the same object per Store class.
96108
# So we map a clear to reset the store.
97109
stores = st.builds(MemoryStore, st.just({})).map(clear_store)
98110
compressors = st.sampled_from([None, "default"])
99111
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([3, 2])
100-
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
112+
# We de-prioritize arrays having dim sizes 0, 1, 2
113+
array_shapes = npst.array_shapes(max_dims=4, min_side=3) | npst.array_shapes(max_dims=4, min_side=0)
101114

102115

103116
@st.composite # type: ignore[misc]
@@ -152,13 +165,15 @@ def numpy_arrays(
152165
draw: st.DrawFn,
153166
*,
154167
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
155-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
168+
dtype: np.dtype[Any] | None = None,
169+
zarr_formats: st.SearchStrategy[ZarrFormat] | None = zarr_formats,
156170
) -> Any:
157171
"""
158172
Generate numpy arrays that can be saved in the provided Zarr format.
159173
"""
160174
zarr_format = draw(zarr_formats)
161-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
175+
if dtype is None:
176+
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
162177
if np.issubdtype(dtype, np.str_):
163178
safe_unicode_strings = safe_unicode_for_dtype(dtype)
164179
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
@@ -174,11 +189,19 @@ def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]:
174189
st.tuples(*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in shape])
175190
)
176191
# 2. and now generate the chunks tuple
177-
return tuple(
192+
chunks = tuple(
178193
size // nchunks if nchunks > 0 else 0
179194
for size, nchunks in zip(shape, numchunks, strict=True)
180195
)
181196

197+
for c in chunks:
198+
event("chunk size", c)
199+
200+
if any((c != 0 and s % c != 0) for s, c in zip(shape, chunks, strict=True)):
201+
event("smaller last chunk")
202+
203+
return chunks
204+
182205

183206
@st.composite # type: ignore[misc]
184207
def shard_shapes(
@@ -211,7 +234,7 @@ def arrays(
211234
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
212235
compressors: st.SearchStrategy = compressors,
213236
stores: st.SearchStrategy[StoreLike] = stores,
214-
paths: st.SearchStrategy[str | None] = paths,
237+
paths: st.SearchStrategy[str | None] = paths(), # noqa: B008
215238
array_names: st.SearchStrategy = array_names,
216239
arrays: st.SearchStrategy | None = None,
217240
attrs: st.SearchStrategy = attrs,
@@ -267,23 +290,56 @@ def arrays(
267290
return a
268291

269292

293+
@st.composite # type: ignore[misc]
294+
def simple_arrays(
295+
draw: st.DrawFn,
296+
*,
297+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
298+
) -> Any:
299+
return draw(
300+
arrays(
301+
shapes=shapes,
302+
paths=paths(max_num_nodes=2),
303+
array_names=short_node_names,
304+
attrs=st.none(),
305+
compressors=st.sampled_from([None, "default"]),
306+
)
307+
)
308+
309+
270310
def is_negative_slice(idx: Any) -> bool:
271311
return isinstance(idx, slice) and idx.step is not None and idx.step < 0
272312

273313

314+
@st.composite # type: ignore[misc]
315+
def end_slices(draw: st.DrawFn, *, shape: tuple[int]) -> Any:
316+
"""
317+
A strategy that slices ranges that include the last chunk.
318+
This is intended to stress-test handling of a possibly smaller last chunk.
319+
"""
320+
slicers = []
321+
for size in shape:
322+
start = draw(st.integers(min_value=size // 2, max_value=size - 1))
323+
length = draw(st.integers(min_value=0, max_value=size - start))
324+
slicers.append(slice(start, start + length))
325+
event("drawing end slice")
326+
return tuple(slicers)
327+
328+
274329
@st.composite # type: ignore[misc]
275330
def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs: Any) -> Any:
276331
"""Basic indices without unsupported negative slices."""
277-
return draw(
278-
npst.basic_indices(shape=shape, **kwargs).filter(
279-
lambda idxr: (
280-
not (
281-
is_negative_slice(idxr)
282-
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
283-
)
332+
strategy = npst.basic_indices(shape=shape, **kwargs).filter(
333+
lambda idxr: (
334+
not (
335+
is_negative_slice(idxr)
336+
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
284337
)
285338
)
286339
)
340+
if math.prod(shape) >= 3:
341+
strategy = end_slices(shape=shape) | strategy
342+
return draw(strategy)
287343

288344

289345
@st.composite # type: ignore[misc]

tests/test_indexing.py

+31
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,18 @@ def test_orthogonal_indexing_fallback_on_getitem_2d(
424424
np.testing.assert_array_equal(z[index], expected_result)
425425

426426

427+
@pytest.mark.skip(reason="fails on ubuntu, windows; numpy=2.2; in CI")
428+
def test_setitem_repeated_index():
429+
array = zarr.array(data=np.zeros((4,)), chunks=(1,))
430+
indexer = np.array([-1, -1, 0, 0])
431+
array.oindex[(indexer,)] = [0, 1, 2, 3]
432+
np.testing.assert_array_equal(array[:], np.array([3, 0, 0, 1]))
433+
434+
indexer = np.array([-1, 0, 0, -1])
435+
array.oindex[(indexer,)] = [0, 1, 2, 3]
436+
np.testing.assert_array_equal(array[:], np.array([2, 0, 0, 3]))
437+
438+
427439
Index = list[int] | tuple[slice | int | list[int], ...]
428440

429441

@@ -815,6 +827,25 @@ def test_set_orthogonal_selection_1d(store: StorePath) -> None:
815827
_test_set_orthogonal_selection(v, a, z, selection)
816828

817829

830+
def test_set_item_1d_last_two_chunks():
831+
# regression test for GH2849
832+
g = zarr.open_group("foo.zarr", zarr_format=3, mode="w")
833+
a = g.create_array("bar", shape=(10,), chunks=(3,), dtype=int)
834+
data = np.array([7, 8, 9])
835+
a[slice(7, 10)] = data
836+
np.testing.assert_array_equal(a[slice(7, 10)], data)
837+
838+
z = zarr.open_group("foo.zarr", mode="w")
839+
z.create_array("zoo", dtype=float, shape=())
840+
z["zoo"][...] = np.array(1) # why doesn't [:] work?
841+
np.testing.assert_equal(z["zoo"][()], np.array(1))
842+
843+
z = zarr.open_group("foo.zarr", mode="w")
844+
z.create_array("zoo", dtype=float, shape=())
845+
z["zoo"][...] = 1 # why doesn't [:] work?
846+
np.testing.assert_equal(z["zoo"][()], np.array(1))
847+
848+
818849
def _test_set_orthogonal_selection_2d(
819850
v: npt.NDArray[np.int_],
820851
a: npt.NDArray[np.int_],

tests/test_properties.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
from numpy.testing import assert_array_equal
34

@@ -18,6 +19,7 @@
1819
basic_indices,
1920
numpy_arrays,
2021
orthogonal_indices,
22+
simple_arrays,
2123
stores,
2224
zarr_formats,
2325
)
@@ -50,13 +52,13 @@ def test_array_creates_implicit_groups(array):
5052

5153
@given(data=st.data())
5254
def test_basic_indexing(data: st.DataObject) -> None:
53-
zarray = data.draw(arrays())
55+
zarray = data.draw(simple_arrays())
5456
nparray = zarray[:]
5557
indexer = data.draw(basic_indices(shape=nparray.shape))
5658
actual = zarray[indexer]
5759
assert_array_equal(nparray[indexer], actual)
5860

59-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
61+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
6062
zarray[indexer] = new_data
6163
nparray[indexer] = new_data
6264
assert_array_equal(nparray, zarray[:])
@@ -65,15 +67,19 @@ def test_basic_indexing(data: st.DataObject) -> None:
6567
@given(data=st.data())
6668
def test_oindex(data: st.DataObject) -> None:
6769
# integer_array_indices can't handle 0-size dimensions.
68-
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
70+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
6971
nparray = zarray[:]
7072

7173
zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
7274
actual = zarray.oindex[zindexer]
7375
assert_array_equal(nparray[npindexer], actual)
7476

7577
assume(zarray.shards is None) # GH2834
76-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
78+
for idxr in npindexer:
79+
if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size:
80+
# behaviour of setitem with repeated indices is not guaranteed in practice
81+
assume(False)
82+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
7783
nparray[npindexer] = new_data
7884
zarray.oindex[zindexer] = new_data
7985
assert_array_equal(nparray, zarray[:])
@@ -82,7 +88,7 @@ def test_oindex(data: st.DataObject) -> None:
8288
@given(data=st.data())
8389
def test_vindex(data: st.DataObject) -> None:
8490
# integer_array_indices can't handle 0-size dimensions.
85-
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
91+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
8692
nparray = zarray[:]
8793

8894
indexer = data.draw(

0 commit comments

Comments
 (0)