
Coalesce and parallelize partial shard reads #3004


Draft: wants to merge 18 commits into base: main (changes from all commits)
3 changes: 3 additions & 0 deletions changes/3004.feature.rst
@@ -0,0 +1,3 @@
Optimizes reading multiple, but not all, chunks from a shard. Chunks are now read in parallel,
and reads of nearby chunks within the same shard are combined to reduce the number of calls to the store.
See :ref:`user-guide-config` for more details.
6 changes: 6 additions & 0 deletions docs/user-guide/config.rst
@@ -33,6 +33,10 @@ Configuration options include the following:
- Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers``
- Selections of implementations of codecs, codec pipelines and buffers
- Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more.
- Tuning reads from sharded arrays. When reading fewer chunks than a complete shard contains,
  reads of nearby chunks within the same shard are combined into a single request if they are
  less than ``sharding.read.coalesce_max_gap_bytes`` apart and the combined request size is
  less than ``sharding.read.coalesce_max_bytes``.
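For example, to make coalescing more aggressive on a high-latency store, the two options above can be raised through ``zarr.config.set``. A sketch only: the option names are the ones introduced in this PR, and the values are purely illustrative:

```python
import zarr

# Allow larger gaps between chunks (4 MiB) and a larger combined
# request size (256 MiB) before splitting into separate store calls.
zarr.config.set(
    {
        "sharding.read.coalesce_max_gap_bytes": 4 * 2**20,
        "sharding.read.coalesce_max_bytes": 256 * 2**20,
    }
)
```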

For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
first register the implementations in the registry and then select them in the config.
@@ -88,4 +92,6 @@ This is the current default configuration::
'default_zarr_format': 3,
'json_indent': 2,
'ndbuffer': 'zarr.core.buffer.cpu.NDBuffer',
'sharding': {'read': {'coalesce_max_bytes': 104857600,
'coalesce_max_gap_bytes': 1048576}},
'threading': {'max_workers': None}}
1 change: 1 addition & 0 deletions pyproject.toml
@@ -347,6 +347,7 @@ ignore = [
[tool.mypy]
python_version = "3.11"
ignore_missing_imports = true
mypy_path = "src"
Contributor Author: I needed to do this for the mypy run in pre-commit to succeed when it was running on tests/test_config.py. Not sure if we want this at all, or if it should go in its own PR.

namespace_packages = false

strict = true
181 changes: 152 additions & 29 deletions src/zarr/codecs/sharding.py
@@ -38,11 +38,13 @@
from zarr.core.common import (
ChunkCoords,
ChunkCoordsLike,
concurrent_map,
parse_enum,
parse_named_configuration,
parse_shapelike,
product,
)
from zarr.core.config import config
from zarr.core.indexing import (
BasicIndexer,
SelectorTuple,
@@ -196,7 +198,9 @@

@classmethod
def create_empty(
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
cls,
chunks_per_shard: ChunkCoords,
buffer_prototype: BufferPrototype | None = None,
Contributor Author: Was there a ruff format version upgrade? I'm leaving these formatting changes in since they were produced by the pre-commit run.

) -> _ShardReader:
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
@@ -246,7 +250,9 @@

@classmethod
def create_empty(
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
cls,
chunks_per_shard: ChunkCoords,
buffer_prototype: BufferPrototype | None = None,
) -> _ShardBuilder:
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
@@ -327,9 +333,18 @@
return await shard_builder.finalize(index_location, index_encoder)


class _ChunkCoordsByteSlice(NamedTuple):
"""Holds a chunk's coordinates and its byte range in a serialized shard."""

coords: ChunkCoords
byte_slice: slice


@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
ArrayBytesCodec,
ArrayBytesCodecPartialDecodeMixin,
ArrayBytesCodecPartialEncodeMixin,
):
chunk_shape: ChunkCoords
codecs: tuple[Codec, ...]
@@ -439,7 +454,10 @@

# setup output array
out = chunk_spec.prototype.nd_buffer.create(
shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
shape=shard_shape,
dtype=shard_spec.dtype,
order=shard_spec.order,
fill_value=0,
)
shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)

@@ -483,39 +501,31 @@

# setup output array
out = shard_spec.prototype.nd_buffer.create(
shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
shape=indexer.shape,
dtype=shard_spec.dtype,
order=shard_spec.order,
fill_value=0,
)

indexed_chunks = list(indexer)
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

# reading bytes of all requested chunks
shard_dict: ShardMapping = {}
shard_dict_maybe: ShardMapping | None = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(
byte_getter=byte_getter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
byte_getter, chunk_spec.prototype, chunks_per_shard
)
if shard_dict_maybe is None:
return None
shard_dict = shard_dict_maybe
else:
# read some chunks within the shard
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
if shard_index is None:
return None
shard_dict = {}
for chunk_coords in all_chunk_coords:
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
chunk_bytes = await byte_getter.get(
prototype=chunk_spec.prototype,
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
)
if chunk_bytes:
shard_dict[chunk_coords] = chunk_bytes
shard_dict_maybe = await self._load_partial_shard_maybe(

byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
)

if shard_dict_maybe is None:
return None

shard_dict = shard_dict_maybe
Contributor Author (on lines -501 to +528): Here's where the non-formatting changes start in this file.


# decoding chunks and writing them into the output buffer
await self.codec_pipeline.read(
@@ -597,7 +607,9 @@

indexer = list(
get_indexer(
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
selection,
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
)
)

@@ -671,7 +683,8 @@
get_pipeline_class()
.from_codecs(self.index_codecs)
.compute_encoded_size(
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
16 * product(chunks_per_shard),
self._get_index_chunk_spec(chunks_per_shard),
)
)

@@ -716,7 +729,8 @@
)
else:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
prototype=numpy_buffer_prototype(),
byte_range=SuffixByteRequest(shard_index_size),
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
@@ -730,7 +744,10 @@
) or _ShardIndex.create_empty(chunks_per_shard)

async def _load_full_shard_maybe(
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: ChunkCoords,
) -> _ShardReader | None:
shard_bytes = await byte_getter.get(prototype=prototype)

@@ -740,6 +757,112 @@
else None
)

async def _load_partial_shard_maybe(
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: ChunkCoords,
all_chunk_coords: set[ChunkCoords],
) -> ShardMapping | None:
"""
Read chunks from `byte_getter` for the case where the read is less than a full shard.
Returns a mapping of chunk coordinates to bytes.
"""
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
if shard_index is None:
return None


chunks = [

_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
for chunk_coords in all_chunk_coords
# Drop chunks where index lookup fails
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
]
if len(chunks) < len(all_chunk_coords):
return None


groups = self._coalesce_chunks(chunks)


shard_dicts = await concurrent_map(

[(group, byte_getter, prototype) for group in groups],
self._get_group_bytes,
config.get("async.concurrency"),
)

shard_dict: ShardMutableMapping = {}
for d in shard_dicts:
if d is None:
return None
shard_dict.update(d)


return shard_dict


def _coalesce_chunks(
self,
chunks: list[_ChunkCoordsByteSlice],
) -> list[list[_ChunkCoordsByteSlice]]:
"""
Combine chunks from a single shard into groups that should be read together
in a single request.

Respects the following configuration options:
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
chunks to coalesce into a single group.
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
"""
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")


sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)


groups = []
current_group = [sorted_chunks[0]]


for chunk in sorted_chunks[1:]:
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
current_group.append(chunk)

else:
groups.append(current_group)
current_group = [chunk]


groups.append(current_group)


return groups


async def _get_group_bytes(
self,
group: list[_ChunkCoordsByteSlice],
byte_getter: ByteGetter,
prototype: BufferPrototype,
) -> ShardMapping | None:
"""
Reads a possibly coalesced group of one or more chunks from a shard.
Returns a mapping of chunk coordinates to bytes.
"""
group_start = group[0].byte_slice.start
group_end = group[-1].byte_slice.stop


# A single call to retrieve the bytes for the entire group.
group_bytes = await byte_getter.get(

prototype=prototype,
byte_range=RangeByteRequest(group_start, group_end),
)
if group_bytes is None:
return None


# Extract the bytes corresponding to each chunk in group from group_bytes.
shard_dict = {}
for chunk in group:
chunk_slice = slice(

chunk.byte_slice.start - group_start,
chunk.byte_slice.stop - group_start,
)
shard_dict[chunk.coords] = group_bytes[chunk_slice]


return shard_dict


def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
return input_byte_length + self._shard_index_size(chunks_per_shard)
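The grouping in `_coalesce_chunks` is a single greedy pass over chunk byte ranges sorted by start offset. A self-contained sketch of the same idea on bare byte spans (the names `ByteSpan` and `coalesce` are illustrative, not part of the zarr API):

```python
from typing import NamedTuple


class ByteSpan(NamedTuple):
    """A chunk's byte range within a serialized shard."""

    start: int
    stop: int


def coalesce(
    spans: list[ByteSpan], max_gap: int, max_bytes: int
) -> list[list[ByteSpan]]:
    """Greedily group spans that are close enough to fetch in one request.

    A span joins the current group when the gap to the group's end is under
    max_gap and the combined size stays under max_bytes; otherwise it
    starts a new group.
    """
    ordered = sorted(spans, key=lambda s: s.start)
    groups = [[ordered[0]]]
    for span in ordered[1:]:
        gap = span.start - groups[-1][-1].stop
        combined = span.stop - groups[-1][0].start
        if gap < max_gap and combined < max_bytes:
            groups[-1].append(span)
        else:
            groups.append([span])
    return groups


spans = [ByteSpan(0, 100), ByteSpan(150, 300), ByteSpan(5000, 5100)]
groups = coalesce(spans, max_gap=1024, max_bytes=10_000)
# The first two spans merge (gap of 50 bytes); the third is too far away.
```

Each resulting group then maps to one ranged `get` from the store, covering the group's first `start` through its last `stop`, after which each chunk's bytes are sliced out at offsets relative to the group start, as `_get_group_bytes` does above.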
6 changes: 6 additions & 0 deletions src/zarr/core/config.py
@@ -108,6 +108,12 @@ def enable_gpu(self) -> ConfigSet:
},
"async": {"concurrency": 10, "timeout": None},
"threading": {"max_workers": None},
"sharding": {
"read": {
"coalesce_max_bytes": 100 * 2**20, # 100MiB
"coalesce_max_gap_bytes": 2**20, # 1MiB
}
},
"json_indent": 2,
"codec_pipeline": {
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",
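The defaults added here are the same values rendered in the default-configuration listing in `docs/user-guide/config.rst`; a quick arithmetic check:

```python
# 100 MiB and 1 MiB, as shown in the docs listing
assert 100 * 2**20 == 104_857_600  # sharding.read.coalesce_max_bytes
assert 2**20 == 1_048_576          # sharding.read.coalesce_max_gap_bytes
```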
2 changes: 1 addition & 1 deletion src/zarr/core/indexing.py
@@ -1193,7 +1193,7 @@
stop = self.chunk_nitems_cumsum[chunk_rix]
out_selection: slice | npt.NDArray[np.intp]
if self.sel_sort is None:
out_selection = slice(start, stop)
out_selection = np.arange(start, stop)

else:
out_selection = self.sel_sort[start:stop]

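The one-line change above swaps a `slice` for an integer index array. Both select the same elements, but the behavior differs: basic slicing returns a view, while integer-array (fancy) indexing returns a copy and can represent arbitrary orderings. A small illustration of that difference, assuming NumPy:

```python
import numpy as np

data = np.arange(10) * 10
start, stop = 2, 5

sliced = data[slice(start, stop)]      # basic indexing: a view of data
fancy = data[np.arange(start, stop)]   # integer-array indexing: a copy

same_elements = np.array_equal(sliced, fancy)  # both select 20, 30, 40
sliced_is_view = sliced.base is data           # views share memory
fancy_is_copy = fancy.base is not data         # fancy indexing copies
```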