diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py
index f32c561a..a369b51d 100644
--- a/src/nwp_consumer/internal/entities/coordinates.py
+++ b/src/nwp_consumer/internal/entities/coordinates.py
@@ -38,6 +38,7 @@
 import dataclasses
 import datetime as dt
 import json
+import logging
 from importlib.metadata import PackageNotFoundError, version
 
 import dask.array
@@ -54,12 +55,14 @@
 except PackageNotFoundError:
     __version__ = "v?"
 
+log = logging.getLogger("nwp-consumer")
+
 
 @dataclasses.dataclass(slots=True)
 class NWPDimensionCoordinateMap:
     """Container for dimensions names and their coordinate index values.
 
-    Each field in the container is a dimension label, and the corresponding
+    Each public field in the container is a dimension label, and the corresponding
     value is a list of the coordinate values for each index along the dimension.
 
     All NWP data has an associated init time, step, and variable,
@@ -91,12 +94,6 @@ class NWPDimensionCoordinateMap:
     """
     longitude: list[float] | None = None
    """The longitude coordinates of the forecast grid in degrees. """
-    maximum_number_of_chunks_in_one_dim: int = 8
-    """ The maximum number of chunks in one dimension.
-    When saving to S3 we might want this to be small, to reduce the number of files saved.
-
-    Will be truncated to 4 decimal places, and ordered as -180 -> 180.
-    """
 
     def __post_init__(self) -> None:
         """Rigidly set input value ordering and precision."""
@@ -119,9 +116,7 @@ def dims(self) -> list[str]:
         Ignores any dimensions that do not have a corresponding coordinate
         index value list.
         """
-        return [f.name for f in dataclasses.fields(self) if
-                getattr(self, f.name) is not None
-                and f.name != "maximum_number_of_chunks_in_one_dim"]
+        return [f.name for f in dataclasses.fields(self) if getattr(self, f.name) is not None]
 
     @property
     def shapemap(self) -> dict[str, int]:
@@ -384,6 +379,8 @@ def determine_region(
             # TODO: of which might loop around the edges of the grid. In this case, it would
             # TODO: be useful to determine if the run is non-contiguous only in that it wraps
             # TODO: around that boundary, and in that case, split it and write it in two goes.
+            # TODO: 2025-01-06: I think this is a resolved problem now that fetch_init_data
+            # TODO: can return a list of DataArrays.
             return Failure(
                 ValueError(
                     f"Coordinate values for dimension '{inner_dim_label}' do not correspond "
@@ -398,7 +395,7 @@ def determine_region(
 
         return Success(slices)
 
-    def default_chunking(self) -> dict[str, int]:
+    def chunking(self, chunk_count_overrides: dict[str, int]) -> dict[str, int]:
         """The expected chunk sizes for each dimension.
 
         A dictionary mapping of dimension labels to the size of a chunk along that
@@ -406,39 +403,50 @@
         that wants to cover the entire dimension should have a size equal to the
         dimension length.
 
-        It defaults to a single chunk per init time and step, and 8 chunks
-        for each entire other dimension. These are purposefully small, to ensure
-        that when perfomring parallel writes, chunk boundaries are not crossed.
+        It defaults to a single chunk per init time, step, and variable coordinate,
+        and 2 chunks along each remaining dimension (1 if it has 8 or fewer values),
+        unless overridden by the `chunk_count_overrides` argument.
+
+        The defaults are purposefully small, to ensure that when performing parallel
+        writes, chunk boundaries are not crossed.
+
+        Args:
+            chunk_count_overrides: A dictionary mapping dimension labels to the
+                number of chunks to split the dimension into.
         """
         out_dict: dict[str, int] = {
             "init_time": 1,
             "step": 1,
+            "variable": 1,
         } | {
-            dim: len(getattr(self, dim)) // self.maximum_number_of_chunks_in_one_dim
-            if len(getattr(self, dim)) > self.maximum_number_of_chunks_in_one_dim else 1
+            dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2)
+            if len(getattr(self, dim)) > 8 else 1
             for dim in self.dims
-            if dim not in ["init_time", "step"]
+            if dim not in ["init_time", "step", "variable"]
         }
         return out_dict
 
-    def as_zeroed_dataarray(self, name: str) -> xr.DataArray:
+    def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray:
         """Express the coordinates as an xarray DataArray.
 
-        Data is populated with zeros and a default chunking scheme is applied.
+        The underlying dask array is a zeroed array with the shape of the dataset,
+        chunked according to the given chunking scheme.
 
         Args:
             name: The name of the DataArray.
+            chunks: A mapping of dimension names to the size of the chunks
+                along the dimensions.
 
         See Also:
-        - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
+            - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
         """
         # Create a dask array of zeros with the shape of the dataset
         # * The values of this are ignored, only the shape and chunks are used
         dummy_values = dask.array.zeros(  # type: ignore
             shape=list(self.shapemap.values()),
-            chunks=tuple([self.default_chunking()[k] for k in self.shapemap]),
+            chunks=tuple([chunks[k] for k in self.shapemap]),
         )
         attrs: dict[str, str] = {
             "produced_by": "".join((
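Note for reviewers: `chunking` now derives chunk counts from a per-call override map instead of the removed per-instance `maximum_number_of_chunks_in_one_dim` field. A minimal sketch of the resulting behaviour, using invented dimension lengths (the real values come from the coordinate map):

```python
# Sketch of the new chunking() logic with made-up dimension lengths.
# Index dimensions always get one chunk; other dimensions are split into
# the overridden count (default 2), unless they are short (<= 8 values).
dim_lengths = {"init_time": 1, "step": 48, "variable": 5, "latitude": 200, "longitude": 200}
overrides = {"latitude": 8, "longitude": 8}

chunks = {"init_time": 1, "step": 1, "variable": 1} | {
    dim: length // overrides.get(dim, 2) if length > 8 else 1
    for dim, length in dim_lengths.items()
    if dim not in ["init_time", "step", "variable"]
}
print(chunks)  # {'init_time': 1, 'step': 1, 'variable': 1, 'latitude': 25, 'longitude': 25}
```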
""" out_dict: dict[str, int] = { "init_time": 1, "step": 1, + "variable": 1, } | { - dim: len(getattr(self, dim)) // self.maximum_number_of_chunks_in_one_dim - if len(getattr(self, dim)) > self.maximum_number_of_chunks_in_one_dim else 1 + dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2) + if len(getattr(self, dim)) > 8 else 1 for dim in self.dims - if dim not in ["init_time", "step"] + if dim not in ["init_time", "step", "variable"] } return out_dict - def as_zeroed_dataarray(self, name: str) -> xr.DataArray: + def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray: """Express the coordinates as an xarray DataArray. - Data is populated with zeros and a default chunking scheme is applied. + The underlying dask array is a zeroed array with the shape of the dataset, + that is chunked according to the given chunking scheme. Args: name: The name of the DataArray. + chunks: A mapping of dimension names to the size of the chunks + along the dimensions. See Also: - - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes + - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes """ # Create a dask array of zeros with the shape of the dataset # * The values of this are ignored, only the shape and chunks are used dummy_values = dask.array.zeros( # type: ignore shape=list(self.shapemap.values()), - chunks=tuple([self.default_chunking()[k] for k in self.shapemap]), + chunks=tuple([chunks[k] for k in self.shapemap]), ) attrs: dict[str, str] = { "produced_by": "".join(( diff --git a/src/nwp_consumer/internal/entities/modelmetadata.py b/src/nwp_consumer/internal/entities/modelmetadata.py index a5dbbc02..2efc671f 100644 --- a/src/nwp_consumer/internal/entities/modelmetadata.py +++ b/src/nwp_consumer/internal/entities/modelmetadata.py @@ -55,6 +55,15 @@ class ModelMetadata: Which prints grid data from the grib file. """ + chunk_count_overrides: dict[str, int] = dataclasses.field(default_factory=dict) + """Mapping of dimension names to the desired number of chunks in that dimension. + + Overrides the default chunking strategy. + + See Also: + - `entities.coordinates.NWPDimensionCoordinateMap.chunking` + """ + def __str__(self) -> str: """Return a pretty-printed string representation of the metadata.""" pretty: str = "".join(( @@ -93,13 +102,14 @@ def with_region(self, region: str) -> "ModelMetadata": log.warning(f"Unknown region '{region}', not cropping expected coordinates.") return self - def set_maximum_number_of_chunks_in_one_dim(self, maximum_number_of_chunks_in_one_dim: int) \ - -> "ModelMetadata": - """Set the maximum number of chunks in one dimension.""" - self.expected_coordinates.maximum_number_of_chunks_in_one_dim \ - = maximum_number_of_chunks_in_one_dim - return self - + def with_chunk_count_overrides(self, overrides: dict[str, int]) -> "ModelMetadata": + """Returns metadata for the given model with the given chunk count overrides.""" + if not set(overrides.keys()).issubset(self.expected_coordinates.dims): + log.warning( + "Chunk count overrides contain keys not in the expected coordinates. 
" + "These will not modify the chunking strategy.", + ) + return dataclasses.replace(self, chunk_count_overrides=overrides) class Models: """Namespace containing known models.""" diff --git a/src/nwp_consumer/internal/entities/tensorstore.py b/src/nwp_consumer/internal/entities/tensorstore.py index b71e3baf..4e866f1b 100644 --- a/src/nwp_consumer/internal/entities/tensorstore.py +++ b/src/nwp_consumer/internal/entities/tensorstore.py @@ -16,7 +16,7 @@ import os import pathlib import shutil -from collections.abc import MutableMapping +from collections.abc import Mapping, MutableMapping from typing import Any import pandas as pd @@ -77,6 +77,7 @@ def initialize_empty_store( model: str, repository: str, coords: NWPDimensionCoordinateMap, + chunks: dict[str, int], ) -> ResultE["TensorStore"]: """Initialize a store for a given init time. @@ -110,6 +111,7 @@ def initialize_empty_store( This is also used as the name of the tensor. repository: The name of the repository providing the tensor data. coords: The coordinates of the store. + chunks: The chunk sizes for the store. Returns: An indicator of a successful store write containing the number of bytes written. @@ -152,7 +154,8 @@ def initialize_empty_store( # Write the coordinates to a skeleton Zarr store # * 'compute=False' enables only saving metadata # * 'mode="w-"' fails if it finds an existing store - da: xr.DataArray = coords.as_zeroed_dataarray(name=model) + + da: xr.DataArray = coords.as_zeroed_dataarray(name=model, chunks=chunks) encoding = { model: {"write_empty_chunks": False}, "init_time": {"units": "nanoseconds since 1970-01-01"}, @@ -257,12 +260,19 @@ def write_to_region( If the region dict is empty or not provided, the region is determined via the 'determine_region' method. + This function should be thread safe, so a check is performed on the region + to ensure that it can be safely written to in parallel, i.e. that it covers + an integer number of chunks. + Args: da: The data to write to the store. region: The region to write to. Returns: An indicator of a successful store write containing the number of bytes written. + + See Also: + - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes """ # Attempt to determine the region if missing if region is None or region == {}: @@ -270,9 +280,32 @@ def write_to_region( self.coordinate_map.determine_region, ) if isinstance(region_result, Failure): - return Failure(region_result.failure()) + return region_result region = region_result.unwrap() + # For each dimensional slice defining the region, check the slice represents an + # integer number of chunks along that dimension. + # * This is to ensure that the data can safely be written in parallel. + # * The start and and of each slice should be divisible by the chunk size. + chunksizes: Mapping[Any, tuple[int, ...]] = xr.open_dataarray( + self.path, engine="zarr", + ).chunksizes + for dim, slc in region.items(): + chunk_size = chunksizes.get(dim, (1,))[0] + # TODO: Determine if this should return a full failure object + if slc.start % chunk_size != 0 or slc.stop % chunk_size != 0: + log.warning( + f"Determined region of raw data to be written for dimension '{dim}'" + f"does not align with chunk boundaries of the store. " + f"Dimension '{dim}' has a chunk size of {chunk_size}, " + "but the data to be written for this dimension starts at chunk " + f"{slc.start / chunk_size:.2f} (index {slc.start}) and ends at chunk " + f"{slc.stop / chunk_size:.2f} (index {slc.stop}). 
" + "As such, this region cannot be safely written in parallel. " + "Ensure the chunking is granular enough to cover the raw data region.", + ) + + # Perform the regional write try: da.to_zarr(store=self.path, region=region, consolidated=True) diff --git a/src/nwp_consumer/internal/entities/test_tensorstore.py b/src/nwp_consumer/internal/entities/test_tensorstore.py index 69b8a835..eab3f530 100644 --- a/src/nwp_consumer/internal/entities/test_tensorstore.py +++ b/src/nwp_consumer/internal/entities/test_tensorstore.py @@ -89,6 +89,7 @@ def store(self, year: int) -> Generator[TensorStore, None, None]: model="test_da", repository="dummy_repository", coords=test_coords, + chunks=test_coords.chunking(chunk_count_overrides={}), ) self.assertIsInstance(init_result, Success, msg=init_result) store = init_result.unwrap() diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/ceda_ftp.py b/src/nwp_consumer/internal/repositories/raw_repositories/ceda_ftp.py index 9b865c72..30cb001e 100644 --- a/src/nwp_consumer/internal/repositories/raw_repositories/ceda_ftp.py +++ b/src/nwp_consumer/internal/repositories/raw_repositories/ceda_ftp.py @@ -121,7 +121,10 @@ def repository() -> entities.RawRepositoryMetadata: optional_env={}, postprocess_options=entities.PostProcessOptions(), available_models={ - "default": entities.Models.MO_UM_GLOBAL_17KM, + "default": entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides({ + "latitude": 8, + "longitude": 8, + }), }, ) diff --git a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py index 0d646a6d..ea297f4a 100644 --- a/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py +++ b/src/nwp_consumer/internal/repositories/raw_repositories/ecmwf_realtime.py @@ -83,10 +83,8 @@ def repository() -> entities.RawRepositoryMetadata: }, postprocess_options=entities.PostProcessOptions(), available_models={ - "default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk") - .set_maximum_number_of_chunks_in_one_dim(2), - "hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk") - .set_maximum_number_of_chunks_in_one_dim(2), + "default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"), + "hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"), "hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india"), }, ) @@ -196,8 +194,9 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: ).with_suffix(".grib").expanduser() # Only download the file if not already present - log.info("Checking for local file: '%s'", local_path) - if not local_path.exists() or local_path.stat().st_size == 0: + if local_path.exists() and local_path.stat().st_size > 0: + log.debug("Skipping download for existing file at '%s'.", local_path.as_posix()) + else: local_path.parent.mkdir(parents=True, exist_ok=True) log.debug("Requesting file from S3 at: '%s'", url) @@ -205,8 +204,8 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: if not self._fs.exists(url): raise FileNotFoundError(f"File not found at '{url}'") + log.debug("Writing file from '%s' to '%s'", url, local_path.as_posix()) with local_path.open("wb") as lf, self._fs.open(url, "rb") as rf: - log.info(f"Writing file from {url} to {local_path}") for chunk in iter(lambda: rf.read(12 * 1024), b""): lf.write(chunk) lf.flush() diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py 
diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py
index 45d6e93d..7064d662 100644
--- a/src/nwp_consumer/internal/services/consumer_service.py
+++ b/src/nwp_consumer/internal/services/consumer_service.py
@@ -106,13 +106,12 @@ def _parallelize_generator[T](
         max_connections: The maximum number of connections to use.
     """
     # TODO: Change this based on threads instead of CPU count
+    # TODO: Enable choosing between threads and processes?
     n_jobs: int = max(cpu_count() - 1, max_connections)
     prefer = "threads"
-    concurrency = os.getenv("CONCURRENCY", "True").capitalize() == "False"
-    if concurrency:
+    if os.getenv("CONCURRENCY", "True").capitalize() == "False":
         n_jobs = 1
-        prefer = "processes"
 
     log.debug(f"Using {n_jobs} concurrent {prefer}")
@@ -156,6 +155,9 @@ def _create_suitable_store(
                 model_metadata.expected_coordinates,
                 init_time=its,
             ),
+            chunks=model_metadata.expected_coordinates.chunking(
+                chunk_count_overrides=model_metadata.chunk_count_overrides,
+            ),
         )
 
     @override
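Finally, the simplified concurrency toggle: setting `CONCURRENCY` to any casing of "false" now just forces a single job, and the misleadingly named `concurrency` flag (which was True precisely when the variable read "False") is gone, along with the switch to processes. A self-contained sketch of the resulting behaviour (the function name is illustrative):

```python
import os


def effective_n_jobs(cpu: int, max_connections: int) -> int:
    # Mirrors the new logic: any casing of "false" disables concurrency,
    # since str.capitalize() normalises "FALSE"/"false" to "False".
    n_jobs = max(cpu - 1, max_connections)
    if os.getenv("CONCURRENCY", "True").capitalize() == "False":
        n_jobs = 1
    return n_jobs


os.environ["CONCURRENCY"] = "FALSE"
assert effective_n_jobs(cpu=8, max_connections=10) == 1
del os.environ["CONCURRENCY"]
assert effective_n_jobs(cpu=8, max_connections=10) == 10
```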