
Commit da780f9

Merge branch 'main' into chore/benchmarks
2 parents 2b2d809 + 95585ee commit da780f9

File tree: 6 files changed, +96 −20 lines


changes/3547.feature.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Add an `array.target_shard_size_bytes` setting to [`zarr.config`][] to allow users to set a maximum number of bytes per shard when `shards="auto"` in, for example, [`zarr.create_array`][].
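For context, a minimal sketch of how this setting might be exercised; the 1 MiB budget, the in-memory store, and the array/chunk shapes below are illustrative assumptions, not values from this change:

```python
import zarr

# Hypothetical budget: cap automatically inferred shards at ~1 MiB each.
with zarr.config.set({"array.target_shard_size_bytes": 1024 * 1024}):
    z = zarr.create_array(
        store={},  # in-memory store, as in the performance docs examples
        shape=(10_000, 10_000),
        chunks=(100, 100),
        shards="auto",  # shard shape inferred from the configured byte budget
        dtype="uint8",
    )
print(z.shards)  # e.g. (1000, 1000): 10**6 bytes per shard, just under the budget
```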

docs/user-guide/performance.md

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,8 @@ z6 = zarr.create_array(store={}, shape=(10000, 10000, 1000), shards=(1000, 1000,
 print(z6.info)
 ```
 
+`shards` can also be `"auto"`, in which case the `array.target_shard_size_bytes` setting controls the shard size: the shard is made as large as possible without exceeding `target_shard_size_bytes`. If the setting is not configured, a default heuristic is used.
+
 ### Chunk memory layout
 
 The order of bytes **within each chunk** of an array can be changed via the

src/zarr/core/chunk_grids.py

Lines changed: 53 additions & 5 deletions
@@ -12,6 +12,7 @@
 
 import numpy as np
 
+import zarr
 from zarr.abc.metadata import Metadata
 from zarr.core.common import (
     JSON,
@@ -202,6 +203,43 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
 )
 
 
+def _guess_num_chunks_per_axis_shard(
+    chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
+) -> int:
+    """Generate the number of chunks per axis to hit a target max byte size for a shard.
+
+    For example, for a (2,2,2) chunk size and item size 4, a maximum of 256 bytes would return 2.
+    In other words, the shard would be a (2,2,2) grid of (2,2,2) chunks,
+    i.e., prod(chunk_shape) * (returned_val ** len(chunk_shape)) * item_size = 256 bytes.
+
+    Parameters
+    ----------
+    chunk_shape
+        The shape of the (inner) chunks.
+    item_size
+        The item size of the data, e.g., 2 for uint16.
+    max_bytes
+        The maximum number of bytes per shard to allow.
+    array_shape
+        The shape of the underlying array.
+
+    Returns
+    -------
+    The number of chunks per axis.
+    """
+    bytes_per_chunk = np.prod(chunk_shape) * item_size
+    if max_bytes < bytes_per_chunk:
+        return 1
+    num_axes = len(chunk_shape)
+    chunks_per_shard = 1
+    # First check keeps the shard within the byte budget; second check keeps it within the array shape.
+    while (bytes_per_chunk * ((chunks_per_shard + 1) ** num_axes)) <= max_bytes and all(
+        c * (chunks_per_shard + 1) <= a for c, a in zip(chunk_shape, array_shape, strict=True)
+    ):
+        chunks_per_shard += 1
+    return chunks_per_shard
+
+
 def _auto_partition(
     *,
     array_shape: tuple[int, ...],
@@ -237,12 +275,22 @@ def _auto_partition(
             stacklevel=2,
         )
         _shards_out = ()
+        target_shard_size_bytes = zarr.config.get("array.target_shard_size_bytes", None)
+        num_chunks_per_shard_axis = (
+            _guess_num_chunks_per_axis_shard(
+                chunk_shape=_chunks_out,
+                item_size=item_size,
+                max_bytes=target_shard_size_bytes,
+                array_shape=array_shape,
+            )
+            if (has_auto_shard := (target_shard_size_bytes is not None))
+            else 2
+        )
         for a_shape, c_shape in zip(array_shape, _chunks_out, strict=True):
-            # TODO: make a better heuristic than this.
-            # for each axis, if there are more than 8 chunks along that axis, then put
-            # 2 chunks in each shard for that axis.
-            if a_shape // c_shape > 8:
-                _shards_out += (c_shape * 2,)
+            # The previous heuristic was `a_shape // c_shape > 8`; now, with target_shard_size_bytes,
+            # we only check that the shard size does not exceed the array size.
+            can_shard_axis = a_shape // c_shape > 8 if not has_auto_shard else True
+            if can_shard_axis:
+                _shards_out += (c_shape * num_chunks_per_shard_axis,)
             else:
                 _shards_out += (c_shape,)
     elif isinstance(shard_shape, dict):
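To make the docstring example above concrete, the new (private) helper can be called directly after this change; the `array_shape` below is an arbitrary assumption, chosen large enough that the array-size guard never binds:

```python
from zarr.core.chunk_grids import _guess_num_chunks_per_axis_shard

# Docstring example: (2, 2, 2) chunks of 4-byte items with a 256-byte shard budget.
n = _guess_num_chunks_per_axis_shard(
    chunk_shape=(2, 2, 2),
    item_size=4,
    max_bytes=256,
    array_shape=(64, 64, 64),  # assumed shape, large enough not to constrain the result
)
print(n)  # 2 -> each shard is a (2, 2, 2) grid of chunks: 4**3 elements * 4 bytes = 256 bytes
```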

src/zarr/core/config.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def enable_gpu(self) -> ConfigSet:
         "array": {
             "order": "C",
             "write_empty_chunks": False,
+            "target_shard_size_bytes": None,
         },
         "async": {"concurrency": 10, "timeout": None},
         "threading": {"max_workers": None},

tests/test_array.py

Lines changed: 38 additions & 15 deletions
@@ -966,33 +966,56 @@ async def test_nbytes(
 
 
 @pytest.mark.parametrize(
-    ("array_shape", "chunk_shape"),
-    [((256,), (2,))],
+    ("array_shape", "chunk_shape", "target_shard_size_bytes", "expected_shards"),
+    [
+        pytest.param(
+            (256, 256),
+            (32, 32),
+            129 * 129,
+            (128, 128),
+            id="2d_chunking_max_byes_does_not_evenly_divide",
+        ),
+        pytest.param(
+            (256, 256), (32, 32), 64 * 64, (64, 64), id="2d_chunking_max_byes_evenly_divides"
+        ),
+        pytest.param(
+            (256, 256),
+            (64, 32),
+            128 * 128,
+            (128, 64),
+            id="2d_non_square_chunking_max_byes_evenly_divides",
+        ),
+        pytest.param((256,), (2,), 255, (254,), id="max_bytes_just_below_array_shape"),
+        pytest.param((256,), (2,), 256, (256,), id="max_bytes_equal_to_array_shape"),
+        pytest.param((256,), (2,), 16, (16,), id="max_bytes_normal_val"),
+        pytest.param((256,), (2,), 2, (2,), id="max_bytes_same_as_chunk"),
+        pytest.param((256,), (2,), 1, (2,), id="max_bytes_less_than_chunk"),
+        pytest.param((256,), (2,), None, (4,), id="use_default_auto_setting"),
+        pytest.param((4,), (2,), None, (2,), id="small_array_shape_does_not_shard"),
+    ],
 )
 def test_auto_partition_auto_shards(
-    array_shape: tuple[int, ...], chunk_shape: tuple[int, ...]
+    array_shape: tuple[int, ...],
+    chunk_shape: tuple[int, ...],
+    target_shard_size_bytes: int | None,
+    expected_shards: tuple[int, ...],
 ) -> None:
     """
     Test that automatically picking a shard size returns a tuple of 2 * the chunk shape for any axis
     where there are 8 or more chunks.
     """
     dtype = np.dtype("uint8")
-    expected_shards: tuple[int, ...] = ()
-    for cs, a_len in zip(chunk_shape, array_shape, strict=False):
-        if a_len // cs >= 8:
-            expected_shards += (2 * cs,)
-        else:
-            expected_shards += (cs,)
     with pytest.warns(
         ZarrUserWarning,
         match="Automatic shard shape inference is experimental and may change without notice.",
     ):
-        auto_shards, _ = _auto_partition(
-            array_shape=array_shape,
-            chunk_shape=chunk_shape,
-            shard_shape="auto",
-            item_size=dtype.itemsize,
-        )
+        with zarr.config.set({"array.target_shard_size_bytes": target_shard_size_bytes}):
+            auto_shards, _ = _auto_partition(
+                array_shape=array_shape,
+                chunk_shape=chunk_shape,
+                shard_shape="auto",
+                item_size=dtype.itemsize,
+            )
     assert auto_shards == expected_shards
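As a sanity check on the first parametrization above, the `(128, 128)` expectation can be reproduced by hand; this is a standalone sketch mirroring the loop in `_guess_num_chunks_per_axis_shard`, not the test itself:

```python
import numpy as np

# First case: uint8 array (256, 256), chunks (32, 32), budget of 129 * 129 = 16641 bytes.
chunk_shape, item_size, max_bytes, array_shape = (32, 32), 1, 129 * 129, (256, 256)

bytes_per_chunk = int(np.prod(chunk_shape)) * item_size  # 1024
chunks_per_shard = 1
while bytes_per_chunk * (chunks_per_shard + 1) ** len(chunk_shape) <= max_bytes and all(
    c * (chunks_per_shard + 1) <= a for c, a in zip(chunk_shape, array_shape, strict=True)
):
    chunks_per_shard += 1

# 1024 * 4**2 = 16384 <= 16641, but 1024 * 5**2 = 25600 > 16641, so 4 chunks per axis.
print(chunks_per_shard)  # 4
print(tuple(c * chunks_per_shard for c in chunk_shape))  # (128, 128) == expected_shards
```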
tests/test_config.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def test_config_defaults_set() -> None:
         "array": {
             "order": "C",
             "write_empty_chunks": False,
+            "target_shard_size_bytes": None,
         },
         "async": {"concurrency": 10, "timeout": None},
         "threading": {"max_workers": None},
