Skip to content

Commit 0f51e30

Browse files
Speed up on, saving to cloud (#214)
Co-authored-by: devsjc <[email protected]>
1 parent db216ee commit 0f51e30

File tree

7 files changed

+101
-22
lines changed

7 files changed

+101
-22
lines changed

src/nwp_consumer/internal/entities/coordinates.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import dataclasses
3939
import datetime as dt
4040
import json
41+
import logging
4142
from importlib.metadata import PackageNotFoundError, version
4243

4344
import dask.array
@@ -54,12 +55,14 @@
5455
except PackageNotFoundError:
5556
__version__ = "v?"
5657

58+
log = logging.getLogger("nwp-consumer")
59+
5760

5861
@dataclasses.dataclass(slots=True)
5962
class NWPDimensionCoordinateMap:
6063
"""Container for dimensions names and their coordinate index values.
6164
62-
Each field in the container is a dimension label, and the corresponding
65+
Each public field in the container is a dimension label, and the corresponding
6366
value is a list of the coordinate values for each index along the dimension.
6467
6568
All NWP data has an associated init time, step, and variable,
@@ -90,10 +93,7 @@ class NWPDimensionCoordinateMap:
9093
Will be truncated to 4 decimal places, and ordered as 90 -> -90.
9194
"""
9295
longitude: list[float] | None = None
93-
"""The longitude coordinates of the forecast grid in degrees.
94-
95-
Will be truncated to 4 decimal places, and ordered as -180 -> 180.
96-
"""
96+
"""The longitude coordinates of the forecast grid in degrees. """
9797

9898
def __post_init__(self) -> None:
9999
"""Rigidly set input value ordering and precision."""
@@ -379,6 +379,8 @@ def determine_region(
379379
# TODO: of which might loop around the edges of the grid. In this case, it would
380380
# TODO: be useful to determine if the run is non-contiguous only in that it wraps
381381
# TODO: around that boundary, and in that case, split it and write it in two goes.
382+
# TODO: 2025-01-06: I think this is a resolved problem now that fetch_init_data
383+
# can return a list of DataArrays.
382384
return Failure(
383385
ValueError(
384386
f"Coordinate values for dimension '{inner_dim_label}' do not correspond "
@@ -393,46 +395,58 @@ def determine_region(
393395

394396
return Success(slices)
395397

396-
def default_chunking(self) -> dict[str, int]:
398+
def chunking(self, chunk_count_overrides: dict[str, int]) -> dict[str, int]:
397399
"""The expected chunk sizes for each dimension.
398400
399401
A dictionary mapping of dimension labels to the size of a chunk along that
400402
dimension. Note that the number is chunk size, not chunk number, so a chunk
401403
that wants to cover the entire dimension should have a size equal to the
402404
dimension length.
403405
404-
It defaults to a single chunk per init time and step, and 8 chunks
405-
for each entire other dimension. These are purposefully small, to ensure
406-
that when perfomring parallel writes, chunk boundaries are not crossed.
406+
It defaults to a single chunk per init time, step, and variable coordinate,
407+
and 2 chunks for each entire other dimension, unless overridden by the
408+
`chunk_count_overrides` argument.
409+
410+
The defaults are purposefully small, to ensure that when performing parallel
411+
writes, chunk boundaries are not crossed.
412+
413+
Args:
414+
chunk_count_overrides: A dictionary mapping dimension labels to the
415+
number of chunks to split the dimension into.
407416
"""
408417
out_dict: dict[str, int] = {
409418
"init_time": 1,
410419
"step": 1,
420+
"variable": 1,
411421
} | {
412-
dim: len(getattr(self, dim)) // 8 if len(getattr(self, dim)) > 8 else 1
422+
dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2)
423+
if len(getattr(self, dim)) > 8 else 1
413424
for dim in self.dims
414-
if dim not in ["init_time", "step"]
425+
if dim not in ["init_time", "step", "variable"]
415426
}
416427

417428
return out_dict
418429

419430

420-
def as_zeroed_dataarray(self, name: str) -> xr.DataArray:
431+
def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray:
421432
"""Express the coordinates as an xarray DataArray.
422433
423-
Data is populated with zeros and a default chunking scheme is applied.
434+
The underlying dask array is a zeroed array with the shape of the dataset,
435+
that is chunked according to the given chunking scheme.
424436
425437
Args:
426438
name: The name of the DataArray.
439+
chunks: A mapping of dimension names to the size of the chunks
440+
along the dimensions.
427441
428442
See Also:
429-
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
443+
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
430444
"""
431445
# Create a dask array of zeros with the shape of the dataset
432446
# * The values of this are ignored, only the shape and chunks are used
433447
dummy_values = dask.array.zeros( # type: ignore
434448
shape=list(self.shapemap.values()),
435-
chunks=tuple([self.default_chunking()[k] for k in self.shapemap]),
449+
chunks=tuple([chunks[k] for k in self.shapemap]),
436450
)
437451
attrs: dict[str, str] = {
438452
"produced_by": "".join((

src/nwp_consumer/internal/entities/modelmetadata.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ class ModelMetadata:
5555
Which prints grid data from the grib file.
5656
"""
5757

58+
chunk_count_overrides: dict[str, int] = dataclasses.field(default_factory=dict)
59+
"""Mapping of dimension names to the desired number of chunks in that dimension.
60+
61+
Overrides the default chunking strategy.
62+
63+
See Also:
64+
- `entities.coordinates.NWPDimensionCoordinateMap.chunking`
65+
"""
66+
5867
def __str__(self) -> str:
5968
"""Return a pretty-printed string representation of the metadata."""
6069
pretty: str = "".join((
@@ -93,6 +102,14 @@ def with_region(self, region: str) -> "ModelMetadata":
93102
log.warning(f"Unknown region '{region}', not cropping expected coordinates.")
94103
return self
95104

105+
def with_chunk_count_overrides(self, overrides: dict[str, int]) -> "ModelMetadata":
106+
"""Returns metadata for the given model with the given chunk count overrides."""
107+
if not set(overrides.keys()).issubset(self.expected_coordinates.dims):
108+
log.warning(
109+
"Chunk count overrides contain keys not in the expected coordinates. "
110+
"These will not modify the chunking strategy.",
111+
)
112+
return dataclasses.replace(self, chunk_count_overrides=overrides)
96113

97114
class Models:
98115
"""Namespace containing known models."""

src/nwp_consumer/internal/entities/tensorstore.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import pathlib
1818
import shutil
19-
from collections.abc import MutableMapping
19+
from collections.abc import Mapping, MutableMapping
2020
from typing import Any
2121

2222
import pandas as pd
@@ -77,6 +77,7 @@ def initialize_empty_store(
7777
model: str,
7878
repository: str,
7979
coords: NWPDimensionCoordinateMap,
80+
chunks: dict[str, int],
8081
) -> ResultE["TensorStore"]:
8182
"""Initialize a store for a given init time.
8283
@@ -110,6 +111,7 @@ def initialize_empty_store(
110111
This is also used as the name of the tensor.
111112
repository: The name of the repository providing the tensor data.
112113
coords: The coordinates of the store.
114+
chunks: The chunk sizes for the store.
113115
114116
Returns:
115117
An indicator of a successful store write containing the number of bytes written.
@@ -152,7 +154,8 @@ def initialize_empty_store(
152154
# Write the coordinates to a skeleton Zarr store
153155
# * 'compute=False' enables only saving metadata
154156
# * 'mode="w-"' fails if it finds an existing store
155-
da: xr.DataArray = coords.as_zeroed_dataarray(name=model)
157+
158+
da: xr.DataArray = coords.as_zeroed_dataarray(name=model, chunks=chunks)
156159
encoding = {
157160
model: {"write_empty_chunks": False},
158161
"init_time": {"units": "nanoseconds since 1970-01-01"},
@@ -257,22 +260,52 @@ def write_to_region(
257260
If the region dict is empty or not provided, the region is determined
258261
via the 'determine_region' method.
259262
263+
This function should be thread safe, so a check is performed on the region
264+
to ensure that it can be safely written to in parallel, i.e. that it covers
265+
an integer number of chunks.
266+
260267
Args:
261268
da: The data to write to the store.
262269
region: The region to write to.
263270
264271
Returns:
265272
An indicator of a successful store write containing the number of bytes written.
273+
274+
See Also:
275+
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
266276
"""
267277
# Attempt to determine the region if missing
268278
if region is None or region == {}:
269279
region_result = NWPDimensionCoordinateMap.from_xarray(da).bind(
270280
self.coordinate_map.determine_region,
271281
)
272282
if isinstance(region_result, Failure):
273-
return Failure(region_result.failure())
283+
return region_result
274284
region = region_result.unwrap()
275285

286+
# For each dimensional slice defining the region, check the slice represents an
287+
# integer number of chunks along that dimension.
288+
# * This is to ensure that the data can safely be written in parallel.
289+
# * The start and end of each slice should be divisible by the chunk size.
290+
chunksizes: Mapping[Any, tuple[int, ...]] = xr.open_dataarray(
291+
self.path, engine="zarr",
292+
).chunksizes
293+
for dim, slc in region.items():
294+
chunk_size = chunksizes.get(dim, (1,))[0]
295+
# TODO: Determine if this should return a full failure object
296+
if slc.start % chunk_size != 0 or slc.stop % chunk_size != 0:
297+
log.warning(
298+
f"Determined region of raw data to be written for dimension '{dim}'"
299+
f"does not align with chunk boundaries of the store. "
300+
f"Dimension '{dim}' has a chunk size of {chunk_size}, "
301+
"but the data to be written for this dimension starts at chunk "
302+
f"{slc.start / chunk_size:.2f} (index {slc.start}) and ends at chunk "
303+
f"{slc.stop / chunk_size:.2f} (index {slc.stop}). "
304+
"As such, this region cannot be safely written in parallel. "
305+
"Ensure the chunking is granular enough to cover the raw data region.",
306+
)
307+
308+
276309
# Perform the regional write
277310
try:
278311
da.to_zarr(store=self.path, region=region, consolidated=True)

src/nwp_consumer/internal/entities/test_tensorstore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def store(self, year: int) -> Generator[TensorStore, None, None]:
8989
model="test_da",
9090
repository="dummy_repository",
9191
coords=test_coords,
92+
chunks=test_coords.chunking(chunk_count_overrides={}),
9293
)
9394
self.assertIsInstance(init_result, Success, msg=init_result)
9495
store = init_result.unwrap()

src/nwp_consumer/internal/repositories/raw_repositories/ceda_ftp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,10 @@ def repository() -> entities.RawRepositoryMetadata:
121121
optional_env={},
122122
postprocess_options=entities.PostProcessOptions(),
123123
available_models={
124-
"default": entities.Models.MO_UM_GLOBAL_17KM,
124+
"default": entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides({
125+
"latitude": 8,
126+
"longitude": 8,
127+
}),
125128
},
126129
)
127130

src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,17 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
194194
).with_suffix(".grib").expanduser()
195195

196196
# Only download the file if not already present
197-
if not local_path.exists() or local_path.stat().st_size == 0:
197+
if local_path.exists() and local_path.stat().st_size > 0:
198+
log.debug("Skipping download for existing file at '%s'.", local_path.as_posix())
199+
else:
198200
local_path.parent.mkdir(parents=True, exist_ok=True)
199201
log.debug("Requesting file from S3 at: '%s'", url)
200202

201203
try:
202204
if not self._fs.exists(url):
203205
raise FileNotFoundError(f"File not found at '{url}'")
204206

207+
log.debug("Writing file from '%s' to '%s'", url, local_path.as_posix())
205208
with local_path.open("wb") as lf, self._fs.open(url, "rb") as rf:
206209
for chunk in iter(lambda: rf.read(12 * 1024), b""):
207210
lf.write(chunk)
@@ -280,6 +283,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
280283
.sortby(variables=["step", "variable", "longitude"])
281284
.sortby(variables="latitude", ascending=False)
282285
)
286+
283287
except Exception as e:
284288
return Failure(ValueError(
285289
f"Error processing dataset {i} from '{path}' to DataArray: {e}",

src/nwp_consumer/internal/services/consumer_service.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,18 @@ def _parallelize_generator[T](
106106
max_connections: The maximum number of connections to use.
107107
"""
108108
# TODO: Change this based on threads instead of CPU count
109+
# TODO: Enable choosing between threads and processes?
109110
n_jobs: int = max(cpu_count() - 1, max_connections)
111+
prefer = "threads"
112+
110113
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
111114
n_jobs = 1
112-
log.debug(f"Using {n_jobs} concurrent thread(s)")
115+
116+
log.debug(f"Using {n_jobs} concurrent {prefer}")
113117

114118
return Parallel( # type: ignore
115119
n_jobs=n_jobs,
116-
prefer="threads",
120+
prefer=prefer,
117121
verbose=0,
118122
return_as="generator_unordered",
119123
)(delayed_generator)
@@ -151,6 +155,9 @@ def _create_suitable_store(
151155
model_metadata.expected_coordinates,
152156
init_time=its,
153157
),
158+
chunks=model_metadata.expected_coordinates.chunking(
159+
chunk_count_overrides=model_metadata.chunk_count_overrides,
160+
),
154161
)
155162

156163
@override

0 commit comments

Comments
 (0)