Coalesce and parallelize partial shard reads #3004
@@ -0,0 +1,3 @@
+Optimizes reading more than one, but not all, of the chunks in a shard. Chunks are now read in parallel,
+and reads of nearby chunks within the same shard are combined to reduce the number of calls to the store.
+See :ref:`user-guide-config` for more details.
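The coalescing and parallelism described above are controlled through zarr's runtime configuration. Below is a minimal sketch of how the relevant options might be tuned; the key names are taken from the code changes in this PR, while the values are purely illustrative and not the documented defaults.

```python
import zarr

# Illustrative values only; see the user-guide-config page for the real defaults.
zarr.config.set(
    {
        # Existing option: limit on concurrent requests when chunk groups are fetched in parallel.
        "async.concurrency": 10,
        # New in this PR: chunks separated by less than this many bytes may be read in one request.
        "sharding.read.coalesce_max_gap_bytes": 2**20,
        # New in this PR: a coalesced request is never allowed to grow beyond this many bytes.
        "sharding.read.coalesce_max_bytes": 100 * 2**20,
    }
)
```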
@@ -38,11 +38,13 @@
 from zarr.core.common import (
     ChunkCoords,
     ChunkCoordsLike,
+    concurrent_map,
     parse_enum,
     parse_named_configuration,
     parse_shapelike,
     product,
 )
+from zarr.core.config import config
 from zarr.core.indexing import (
     BasicIndexer,
     SelectorTuple,
@@ -196,7 +198,9 @@

     @classmethod
     def create_empty(
-        cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
+        cls,
+        chunks_per_shard: ChunkCoords,
+        buffer_prototype: BufferPrototype | None = None,
     ) -> _ShardReader:
         if buffer_prototype is None:
             buffer_prototype = default_buffer_prototype()

Review comment: Was there a ruff format version upgrade? I'm leaving these formatting changes in since they were produced by the pre-commit run.
@@ -246,7 +250,9 @@

     @classmethod
     def create_empty(
-        cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
+        cls,
+        chunks_per_shard: ChunkCoords,
+        buffer_prototype: BufferPrototype | None = None,
     ) -> _ShardBuilder:
         if buffer_prototype is None:
             buffer_prototype = default_buffer_prototype()
@@ -327,9 +333,18 @@
         return await shard_builder.finalize(index_location, index_encoder)


+class _ChunkCoordsByteSlice(NamedTuple):
+    """Holds a chunk's coordinates and its byte range in a serialized shard."""
+
+    coords: ChunkCoords
+    byte_slice: slice
+
+
 @dataclass(frozen=True)
 class ShardingCodec(
-    ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
+    ArrayBytesCodec,
+    ArrayBytesCodecPartialDecodeMixin,
+    ArrayBytesCodecPartialEncodeMixin,
 ):
     chunk_shape: ChunkCoords
     codecs: tuple[Codec, ...]
@@ -439,7 +454,10 @@

         # setup output array
         out = chunk_spec.prototype.nd_buffer.create(
-            shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
+            shape=shard_shape,
+            dtype=shard_spec.dtype,
+            order=shard_spec.order,
+            fill_value=0,
         )
         shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)

@@ -483,39 +501,31 @@

         # setup output array
         out = shard_spec.prototype.nd_buffer.create(
-            shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
+            shape=indexer.shape,
+            dtype=shard_spec.dtype,
+            order=shard_spec.order,
+            fill_value=0,
         )

         indexed_chunks = list(indexer)
         all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
-        shard_dict: ShardMapping = {}
+        shard_dict_maybe: ShardMapping | None = {}
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
             shard_dict_maybe = await self._load_full_shard_maybe(
-                byte_getter=byte_getter,
-                prototype=chunk_spec.prototype,
-                chunks_per_shard=chunks_per_shard,
+                byte_getter, chunk_spec.prototype, chunks_per_shard
             )
-            if shard_dict_maybe is None:
-                return None
-            shard_dict = shard_dict_maybe
         else:
             # read some chunks within the shard
-            shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
-            if shard_index is None:
-                return None
-            shard_dict = {}
-            for chunk_coords in all_chunk_coords:
-                chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
-                if chunk_byte_slice:
-                    chunk_bytes = await byte_getter.get(
-                        prototype=chunk_spec.prototype,
-                        byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
-                    )
-                    if chunk_bytes:
-                        shard_dict[chunk_coords] = chunk_bytes
+            shard_dict_maybe = await self._load_partial_shard_maybe(
+                byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
+            )
+
+        if shard_dict_maybe is None:
+            return None
+        shard_dict = shard_dict_maybe

         # decoding chunks and writing them into the output buffer
         await self.codec_pipeline.read(

Review comment on lines -501 to +528: Here's where the non-formatting changes start in this file.
@@ -597,7 +607,9 @@

         indexer = list(
             get_indexer(
-                selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
+                selection,
+                shape=shard_shape,
+                chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
             )
         )

@@ -671,7 +683,8 @@
             get_pipeline_class()
             .from_codecs(self.index_codecs)
             .compute_encoded_size(
-                16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
+                16 * product(chunks_per_shard),
+                self._get_index_chunk_spec(chunks_per_shard),
             )
         )

@@ -716,7 +729,8 @@
             )
         else:
             index_bytes = await byte_getter.get(
-                prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
+                prototype=numpy_buffer_prototype(),
+                byte_range=SuffixByteRequest(shard_index_size),
             )
         if index_bytes is not None:
             return await self._decode_shard_index(index_bytes, chunks_per_shard)
@@ -730,7 +744,10 @@
         ) or _ShardIndex.create_empty(chunks_per_shard)

     async def _load_full_shard_maybe(
-        self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
+        self,
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+        chunks_per_shard: ChunkCoords,
     ) -> _ShardReader | None:
         shard_bytes = await byte_getter.get(prototype=prototype)

@@ -740,6 +757,112 @@
             else None
         )

+    async def _load_partial_shard_maybe(
+        self,
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+        chunks_per_shard: ChunkCoords,
+        all_chunk_coords: set[ChunkCoords],
+    ) -> ShardMapping | None:
+        """
+        Read chunks from `byte_getter` for the case where the read is less than a full shard.
+        Returns a mapping of chunk coordinates to bytes.
+        """
+        shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
+        if shard_index is None:
+            return None
+
+        chunks = [
+            _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
+            for chunk_coords in all_chunk_coords
+            # Drop chunks where index lookup fails
+            if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
+        ]
+        if len(chunks) < len(all_chunk_coords):
+            return None
+
+        groups = self._coalesce_chunks(chunks)
+
+        shard_dicts = await concurrent_map(
+            [(group, byte_getter, prototype) for group in groups],
+            self._get_group_bytes,
+            config.get("async.concurrency"),
+        )
+
+        shard_dict: ShardMutableMapping = {}
+        for d in shard_dicts:
+            if d is None:
+                return None
+            shard_dict.update(d)
+
+        return shard_dict
+
+    def _coalesce_chunks(
+        self,
+        chunks: list[_ChunkCoordsByteSlice],
+    ) -> list[list[_ChunkCoordsByteSlice]]:
+        """
+        Combine chunks from a single shard into groups that should be read together
+        in a single request.
+
+        Respects the following configuration options:
+        - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
+          chunks to coalesce into a single group.
+        - `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
+        """
+        max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
+        coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
+
+        sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
+
+        groups = []
+        current_group = [sorted_chunks[0]]
+
+        for chunk in sorted_chunks[1:]:
+            gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
+            size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
+            if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
+                current_group.append(chunk)
+            else:
+                groups.append(current_group)
+                current_group = [chunk]
+
+        groups.append(current_group)
+
+        return groups
+
+    async def _get_group_bytes(
+        self,
+        group: list[_ChunkCoordsByteSlice],
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+    ) -> ShardMapping | None:
+        """
+        Reads a possibly coalesced group of one or more chunks from a shard.
+        Returns a mapping of chunk coordinates to bytes.
+        """
+        group_start = group[0].byte_slice.start
+        group_end = group[-1].byte_slice.stop
+
+        # A single call to retrieve the bytes for the entire group.
+        group_bytes = await byte_getter.get(
+            prototype=prototype,
+            byte_range=RangeByteRequest(group_start, group_end),
+        )
+        if group_bytes is None:
+            return None
+
+        # Extract the bytes corresponding to each chunk in group from group_bytes.
+        shard_dict = {}
+        for chunk in group:
+            chunk_slice = slice(
+                chunk.byte_slice.start - group_start,
+                chunk.byte_slice.stop - group_start,
+            )
+            shard_dict[chunk.coords] = group_bytes[chunk_slice]
+
+        return shard_dict
+
     def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         return input_byte_length + self._shard_index_size(chunks_per_shard)
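To make the grouping rule in `_coalesce_chunks` concrete, here is a standalone sketch that mirrors its logic on made-up byte ranges; the `Chunk` helper, the sample offsets, and the thresholds are all illustrative and not part of the PR.

```python
from typing import NamedTuple


class Chunk(NamedTuple):
    coords: tuple[int, ...]
    byte_slice: slice


def coalesce(chunks: list[Chunk], max_gap: int, max_size: int) -> list[list[Chunk]]:
    """Group byte ranges that sit close together, capping each group's total span."""
    sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
    groups = [[sorted_chunks[0]]]
    for chunk in sorted_chunks[1:]:
        gap = chunk.byte_slice.start - groups[-1][-1].byte_slice.stop
        span = chunk.byte_slice.stop - groups[-1][0].byte_slice.start
        if gap < max_gap and span < max_size:
            groups[-1].append(chunk)  # close enough: fetch in the same request
        else:
            groups.append([chunk])  # too far away or too large: start a new request
    return groups


chunks = [
    Chunk((0, 0), slice(0, 100)),
    Chunk((0, 1), slice(150, 300)),  # 50-byte gap from the previous chunk
    Chunk((3, 3), slice(10_000, 10_500)),  # far away: ends up in its own group
]
print([[c.coords for c in g] for g in coalesce(chunks, max_gap=1_000, max_size=5_000)])
# [[(0, 0), (0, 1)], [(3, 3)]]
```

Each resulting group corresponds to one `RangeByteRequest` issued by `_get_group_bytes`, and the groups themselves are fetched concurrently via `concurrent_map`.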
Review comment: I needed to do this for the mypy run in pre-commit to succeed when it was running on tests/test_config.py. Not sure if we want this at all, or if it should go in its own PR.