fix(coordinate): Log warning on unsafe regional writes #216

Merged: 4 commits, Jan 7, 2025
50 changes: 29 additions & 21 deletions src/nwp_consumer/internal/entities/coordinates.py
@@ -38,6 +38,7 @@
import dataclasses
import datetime as dt
import json
import logging
from importlib.metadata import PackageNotFoundError, version

import dask.array
@@ -54,12 +55,14 @@
except PackageNotFoundError:
__version__ = "v?"

log = logging.getLogger("nwp-consumer")


@dataclasses.dataclass(slots=True)
class NWPDimensionCoordinateMap:
"""Container for dimensions names and their coordinate index values.

Each field in the container is a dimension label, and the corresponding
Each public field in the container is a dimension label, and the corresponding
value is a list of the coordinate values for each index along the dimension.

All NWP data has an associated init time, step, and variable,
@@ -91,12 +94,6 @@ class NWPDimensionCoordinateMap:
"""
longitude: list[float] | None = None
"""The longitude coordinates of the forecast grid in degrees. """
maximum_number_of_chunks_in_one_dim: int = 8
""" The maximum number of chunks in one dimension.
When saving to S3 we might want this to be small, to reduce the number of files saved.

Will be truncated to 4 decimal places, and ordered as -180 -> 180.
"""

def __post_init__(self) -> None:
"""Rigidly set input value ordering and precision."""
@@ -119,9 +116,7 @@ def dims(self) -> list[str]:
Ignores any dimensions that do not have a corresponding coordinate
index value list.
"""
return [f.name for f in dataclasses.fields(self) if
getattr(self, f.name) is not None
and f.name != "maximum_number_of_chunks_in_one_dim"]
return [f.name for f in dataclasses.fields(self) if getattr(self, f.name) is not None]

@property
def shapemap(self) -> dict[str, int]:
@@ -384,6 +379,8 @@ def determine_region(
# TODO: of which might loop around the edges of the grid. In this case, it would
# TODO: be useful to determine if the run is non-contiguous only in that it wraps
# TODO: around that boundary, and in that case, split it and write it in two goes.
# TODO: 2025-01-06: I think this is a resolved problem now that fetch_init_data
# can return a list of DataArrays.
return Failure(
ValueError(
f"Coordinate values for dimension '{inner_dim_label}' do not correspond "
@@ -398,47 +395,58 @@ def determine_region(

return Success(slices)

def default_chunking(self) -> dict[str, int]:
def chunking(self, chunk_count_overrides: dict[str, int]) -> dict[str, int]:
"""The expected chunk sizes for each dimension.

A dictionary mapping of dimension labels to the size of a chunk along that
dimension. Note that the number is chunk size, not chunk number, so a chunk
that wants to cover the entire dimension should have a size equal to the
dimension length.

It defaults to a single chunk per init time and step, and 8 chunks
for each entire other dimension. These are purposefully small, to ensure
that when perfomring parallel writes, chunk boundaries are not crossed.
It defaults to a single chunk per init time, step, and variable coordinate,
and 2 chunks for each entire other dimension, unless overridden by the
`chunk_count_overrides` argument.

The defaults are purposefully small, to ensure that when performing parallel
writes, chunk boundaries are not crossed.

Args:
chunk_count_overrides: A dictionary mapping dimension labels to the
number of chunks to split the dimension into.
"""
out_dict: dict[str, int] = {
"init_time": 1,
"step": 1,
"variable": 1,
} | {
dim: len(getattr(self, dim)) // self.maximum_number_of_chunks_in_one_dim
if len(getattr(self, dim)) > self.maximum_number_of_chunks_in_one_dim else 1
dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2)
if len(getattr(self, dim)) > 8 else 1
Contributor: should it be >2?

Collaborator (Author): I think if the size is 8 or under it's probably fine to be in one chunk by default since that's pretty small anyway!

for dim in self.dims
if dim not in ["init_time", "step"]
if dim not in ["init_time", "step", "variable"]
}

return out_dict
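
For illustration, here is a standalone sketch of the rule implemented above, using made-up dimension lengths. The helper below is not part of the PR; it just mirrors the logic of `chunking`:

```python
# Hypothetical helper mirroring the chunking rule above: one chunk per
# init_time/step/variable, and len(dim) // n values per chunk elsewhere,
# where n comes from the overrides and defaults to 2. Dimensions with 8 or
# fewer values fall back to a chunk size of 1.
def sketch_chunking(
    dim_lengths: dict[str, int],
    chunk_count_overrides: dict[str, int],
) -> dict[str, int]:
    out: dict[str, int] = {"init_time": 1, "step": 1, "variable": 1}
    for dim, length in dim_lengths.items():
        if dim in ("init_time", "step", "variable"):
            continue
        out[dim] = length // chunk_count_overrides.get(dim, 2) if length > 8 else 1
    return out


# Illustrative grid sizes only: 2 chunks per spatial dimension by default...
dims = {"init_time": 1, "step": 49, "variable": 5, "latitude": 120, "longitude": 100}
print(sketch_chunking(dims, {}))
# -> {'init_time': 1, 'step': 1, 'variable': 1, 'latitude': 60, 'longitude': 50}

# ...or 8 chunks per spatial dimension when overridden, as the Met Office
# global test configuration does further down in this diff.
print(sketch_chunking(dims, {"latitude": 8, "longitude": 8}))
# -> {'init_time': 1, 'step': 1, 'variable': 1, 'latitude': 15, 'longitude': 12}
```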


def as_zeroed_dataarray(self, name: str) -> xr.DataArray:
def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray:
"""Express the coordinates as an xarray DataArray.

Data is populated with zeros and a default chunking scheme is applied.
The underlying dask array is a zeroed array with the shape of the dataset,
that is chunked according to the given chunking scheme.

Args:
name: The name of the DataArray.
chunks: A mapping of dimension names to the size of the chunks
along the dimensions.

See Also:
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
"""
# Create a dask array of zeros with the shape of the dataset
# * The values of this are ignored, only the shape and chunks are used
dummy_values = dask.array.zeros( # type: ignore
shape=list(self.shapemap.values()),
chunks=tuple([self.default_chunking()[k] for k in self.shapemap]),
chunks=tuple([chunks[k] for k in self.shapemap]),
)
attrs: dict[str, str] = {
"produced_by": "".join((
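
As context for the `as_zeroed_dataarray` change above, a minimal sketch of the metadata-only skeleton-store pattern it feeds into. This is generic xarray/dask usage with illustrative dimension names and sizes, not the project's exact call sites:

```python
import dask.array
import numpy as np
import xarray as xr

# Illustrative coordinates: one init time, four steps, a 16x16 grid.
coords = {
    "init_time": np.array(["2025-01-06T00:00"], dtype="datetime64[ns]"),
    "step": np.arange(4),
    "latitude": np.linspace(60.0, 50.0, 16),
    "longitude": np.linspace(-10.0, 2.0, 16),
}
shape = tuple(len(v) for v in coords.values())
chunks = (1, 1, 8, 8)  # one chunk per init_time/step, two chunks per spatial dim

# The values are ignored: only the shape, chunking, and coordinates are written.
dummy = dask.array.zeros(shape=shape, chunks=chunks)
da = xr.DataArray(name="example", data=dummy, coords=coords, dims=list(coords))

# compute=False writes only the metadata skeleton, leaving regions to be filled
# later by parallel `to_zarr(..., region=...)` calls; mode="w-" fails if the
# store already exists.
da.to_zarr("/tmp/example_skeleton.zarr", compute=False, mode="w-")
```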
24 changes: 17 additions & 7 deletions src/nwp_consumer/internal/entities/modelmetadata.py
@@ -55,6 +55,15 @@ class ModelMetadata:
Which prints grid data from the grib file.
"""

chunk_count_overrides: dict[str, int] = dataclasses.field(default_factory=dict)
"""Mapping of dimension names to the desired number of chunks in that dimension.

Overrides the default chunking strategy.

See Also:
- `entities.coordinates.NWPDimensionCoordinateMap.chunking`
"""

def __str__(self) -> str:
"""Return a pretty-printed string representation of the metadata."""
pretty: str = "".join((
@@ -93,13 +102,14 @@ def with_region(self, region: str) -> "ModelMetadata":
log.warning(f"Unknown region '{region}', not cropping expected coordinates.")
return self

def set_maximum_number_of_chunks_in_one_dim(self, maximum_number_of_chunks_in_one_dim: int) \
-> "ModelMetadata":
"""Set the maximum number of chunks in one dimension."""
self.expected_coordinates.maximum_number_of_chunks_in_one_dim \
= maximum_number_of_chunks_in_one_dim
return self

def with_chunk_count_overrides(self, overrides: dict[str, int]) -> "ModelMetadata":
"""Returns metadata for the given model with the given chunk count overrides."""
if not set(overrides.keys()).issubset(self.expected_coordinates.dims):
log.warning(
"Chunk count overrides contain keys not in the expected coordinates. "
"These will not modify the chunking strategy.",
)
return dataclasses.replace(self, chunk_count_overrides=overrides)
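
A brief usage sketch of the new builder method. The import path is an assumption based on the repository layout, and the chunking call mirrors `_create_suitable_store` further down in this diff:

```python
# Assumed import path, mirroring the test modules later in this diff.
from nwp_consumer.internal import entities

# Split each spatial dimension into 8 chunks, as the Met Office global test
# configuration below does; unknown keys would only trigger a logged warning.
model = entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides({
    "latitude": 8,
    "longitude": 8,
})

# The overrides are then consumed when building the store's chunk sizes.
chunks = model.expected_coordinates.chunking(
    chunk_count_overrides=model.chunk_count_overrides,
)
```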

class Models:
"""Namespace containing known models."""
39 changes: 36 additions & 3 deletions src/nwp_consumer/internal/entities/tensorstore.py
@@ -16,7 +16,7 @@
import os
import pathlib
import shutil
from collections.abc import MutableMapping
from collections.abc import Mapping, MutableMapping
from typing import Any

import pandas as pd
@@ -77,6 +77,7 @@ def initialize_empty_store(
model: str,
repository: str,
coords: NWPDimensionCoordinateMap,
chunks: dict[str, int],
) -> ResultE["TensorStore"]:
"""Initialize a store for a given init time.

@@ -110,6 +111,7 @@ def initialize_empty_store(
This is also used as the name of the tensor.
repository: The name of the repository providing the tensor data.
coords: The coordinates of the store.
chunks: The chunk sizes for the store.

Returns:
An indicator of a successful store write containing the number of bytes written.
@@ -152,7 +154,8 @@ def initialize_empty_store(
# Write the coordinates to a skeleton Zarr store
# * 'compute=False' enables only saving metadata
# * 'mode="w-"' fails if it finds an existing store
da: xr.DataArray = coords.as_zeroed_dataarray(name=model)

da: xr.DataArray = coords.as_zeroed_dataarray(name=model, chunks=chunks)
encoding = {
model: {"write_empty_chunks": False},
"init_time": {"units": "nanoseconds since 1970-01-01"},
@@ -257,22 +260,52 @@ def write_to_region(
If the region dict is empty or not provided, the region is determined
via the 'determine_region' method.

This function should be thread safe, so a check is performed on the region
to ensure that it can be safely written to in parallel, i.e. that it covers
an integer number of chunks.

Args:
da: The data to write to the store.
region: The region to write to.

Returns:
An indicator of a successful store write containing the number of bytes written.

See Also:
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
"""
# Attempt to determine the region if missing
if region is None or region == {}:
region_result = NWPDimensionCoordinateMap.from_xarray(da).bind(
self.coordinate_map.determine_region,
)
if isinstance(region_result, Failure):
return Failure(region_result.failure())
return region_result
region = region_result.unwrap()

# For each dimensional slice defining the region, check the slice represents an
# integer number of chunks along that dimension.
# * This is to ensure that the data can safely be written in parallel.
# * The start and end of each slice should be divisible by the chunk size.
chunksizes: Mapping[Any, tuple[int, ...]] = xr.open_dataarray(
self.path, engine="zarr",
).chunksizes
for dim, slc in region.items():
chunk_size = chunksizes.get(dim, (1,))[0]
# TODO: Determine if this should return a full failure object
if slc.start % chunk_size != 0 or slc.stop % chunk_size != 0:
log.warning(
f"Determined region of raw data to be written for dimension '{dim}'"
f"does not align with chunk boundaries of the store. "
f"Dimension '{dim}' has a chunk size of {chunk_size}, "
"but the data to be written for this dimension starts at chunk "
f"{slc.start / chunk_size:.2f} (index {slc.start}) and ends at chunk "
f"{slc.stop / chunk_size:.2f} (index {slc.stop}). "
"As such, this region cannot be safely written in parallel. "
"Ensure the chunking is granular enough to cover the raw data region.",
)


# Perform the regional write
try:
da.to_zarr(store=self.path, region=region, consolidated=True)
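
To make the new safety check concrete, here is a standalone sketch of the alignment rule behind the warning. It is a hypothetical helper, simplified to a single chunk size per dimension rather than xarray's per-dimension tuples:

```python
# A regional write is parallel-safe only when every slice in the region starts
# and stops on a chunk boundary of the target store.
def region_is_chunk_aligned(
    region: dict[str, slice],
    chunksizes: dict[str, int],
) -> bool:
    for dim, slc in region.items():
        chunk_size = chunksizes.get(dim, 1)
        if slc.start % chunk_size != 0 or slc.stop % chunk_size != 0:
            return False
    return True


# A store chunked every 10 steps: writing steps [0, 20) lands exactly on two
# chunks, while [5, 25) straddles three chunks and could race with other writers.
print(region_is_chunk_aligned({"step": slice(0, 20)}, {"step": 10}))  # True
print(region_is_chunk_aligned({"step": slice(5, 25)}, {"step": 10}))  # False
```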
1 change: 1 addition & 0 deletions src/nwp_consumer/internal/entities/test_tensorstore.py
@@ -89,6 +89,7 @@ def store(self, year: int) -> Generator[TensorStore, None, None]:
model="test_da",
repository="dummy_repository",
coords=test_coords,
chunks=test_coords.chunking(chunk_count_overrides={}),
)
self.assertIsInstance(init_result, Success, msg=init_result)
store = init_result.unwrap()
@@ -121,7 +121,10 @@ def repository() -> entities.RawRepositoryMetadata:
optional_env={},
postprocess_options=entities.PostProcessOptions(),
available_models={
"default": entities.Models.MO_UM_GLOBAL_17KM,
"default": entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides({
"latitude": 8,
"longitude": 8,
}),
},
)

@@ -83,10 +83,8 @@ def repository() -> entities.RawRepositoryMetadata:
},
postprocess_options=entities.PostProcessOptions(),
available_models={
"default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk")
.set_maximum_number_of_chunks_in_one_dim(2),
"hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk")
.set_maximum_number_of_chunks_in_one_dim(2),
"default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india"),
},
)
@@ -196,17 +194,18 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
).with_suffix(".grib").expanduser()

# Only download the file if not already present
log.info("Checking for local file: '%s'", local_path)
if not local_path.exists() or local_path.stat().st_size == 0:
if local_path.exists() and local_path.stat().st_size > 0:
log.debug("Skipping download for existing file at '%s'.", local_path.as_posix())
else:
local_path.parent.mkdir(parents=True, exist_ok=True)
log.debug("Requesting file from S3 at: '%s'", url)

try:
if not self._fs.exists(url):
raise FileNotFoundError(f"File not found at '{url}'")

log.debug("Writing file from '%s' to '%s'", url, local_path.as_posix())
with local_path.open("wb") as lf, self._fs.open(url, "rb") as rf:
log.info(f"Writing file from {url} to {local_path}")
for chunk in iter(lambda: rf.read(12 * 1024), b""):
lf.write(chunk)
lf.flush()
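
For reference, the streamed-download pattern in the refactored `_download` above boils down to the following sketch, assuming an fsspec/s3fs filesystem and a hypothetical bucket and key:

```python
import pathlib

import fsspec

# Hypothetical source object and destination path.
url = "s3://example-bucket/forecasts/example.grib"
local_path = pathlib.Path("/tmp/example.grib")

fs = fsspec.filesystem("s3", anon=True)

# Skip the download when a non-empty local copy already exists.
if local_path.exists() and local_path.stat().st_size > 0:
    print(f"Skipping download for existing file at '{local_path}'")
else:
    local_path.parent.mkdir(parents=True, exist_ok=True)
    if not fs.exists(url):
        raise FileNotFoundError(f"File not found at '{url}'")
    # Stream the object to disk in fixed-size chunks rather than reading it
    # into memory in one go.
    with local_path.open("wb") as lf, fs.open(url, "rb") as rf:
        for chunk in iter(lambda: rf.read(12 * 1024), b""):
            lf.write(chunk)
        lf.flush()
```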
8 changes: 5 additions & 3 deletions src/nwp_consumer/internal/services/consumer_service.py
@@ -106,13 +106,12 @@ def _parallelize_generator[T](
max_connections: The maximum number of connections to use.
"""
# TODO: Change this based on threads instead of CPU count
# TODO: Enable choosing between threads and processes?
n_jobs: int = max(cpu_count() - 1, max_connections)
prefer = "threads"

concurrency = os.getenv("CONCURRENCY", "True").capitalize() == "False"
if concurrency:
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
prefer = "processes"

log.debug(f"Using {n_jobs} concurrent {prefer}")

@@ -156,6 +155,9 @@ def _create_suitable_store(
model_metadata.expected_coordinates,
init_time=its,
),
chunks=model_metadata.expected_coordinates.chunking(
chunk_count_overrides=model_metadata.chunk_count_overrides,
),
)

@override
Expand Down