3838import dataclasses
3939import datetime as dt
4040import json
41+ import logging
4142from importlib .metadata import PackageNotFoundError , version
4243
4344import dask .array
5455except PackageNotFoundError :
5556 __version__ = "v?"
5657
58+ log = logging .getLogger ("nwp-consumer" )
59+
5760
5861@dataclasses .dataclass (slots = True )
5962class NWPDimensionCoordinateMap :
6063 """Container for dimensions names and their coordinate index values.
6164
62- Each field in the container is a dimension label, and the corresponding
65+ Each public field in the container is a dimension label, and the corresponding
6366 value is a list of the coordinate values for each index along the dimension.
6467
6568 All NWP data has an associated init time, step, and variable,
@@ -90,10 +93,7 @@ class NWPDimensionCoordinateMap:
9093 Will be truncated to 4 decimal places, and ordered as 90 -> -90.
9194 """
9295 longitude : list [float ] | None = None
93- """The longitude coordinates of the forecast grid in degrees.
94-
95- Will be truncated to 4 decimal places, and ordered as -180 -> 180.
96- """
96+ """The longitude coordinates of the forecast grid in degrees. """
9797
9898 def __post_init__ (self ) -> None :
9999 """Rigidly set input value ordering and precision."""
@@ -379,6 +379,8 @@ def determine_region(
379379 # TODO: of which might loop around the edges of the grid. In this case, it would
380380 # TODO: be useful to determine if the run is non-contiguous only in that it wraps
381381 # TODO: around that boundary, and in that case, split it and write it in two goes.
382+ # TODO: 2025-01-06: I think this is a resolved problem now that fetch_init_data
383+ # can return a list of DataArrays.
382384 return Failure (
383385 ValueError (
384386 f"Coordinate values for dimension '{ inner_dim_label } ' do not correspond "
@@ -393,46 +395,58 @@ def determine_region(
393395
394396 return Success (slices )
395397
396- def default_chunking (self ) -> dict [str , int ]:
398+ def chunking (self , chunk_count_overrides : dict [ str , int ] ) -> dict [str , int ]:
397399 """The expected chunk sizes for each dimension.
398400
399401 A dictionary mapping of dimension labels to the size of a chunk along that
400402 dimension. Note that the number is chunk size, not chunk number, so a chunk
401403 that wants to cover the entire dimension should have a size equal to the
402404 dimension length.
403405
404- It defaults to a single chunk per init time and step, and 8 chunks
405- for each entire other dimension. These are purposefully small, to ensure
406- that when perfomring parallel writes, chunk boundaries are not crossed.
406+ It defaults to a single chunk per init time, step, and variable coordinate,
407+ and 2 chunks for each entire other dimension, unless overridden by the
408+ `chunk_count_overrides` argument.
409+
410+ The defaults are purposefully small, to ensure that when performing parallel
411+ writes, chunk boundaries are not crossed.
412+
413+ Args:
414+ chunk_count_overrides: A dictionary mapping dimension labels to the
415+ number of chunks to split the dimension into.
407416 """
408417 out_dict : dict [str , int ] = {
409418 "init_time" : 1 ,
410419 "step" : 1 ,
420+ "variable" : 1 ,
411421 } | {
412- dim : len (getattr (self , dim )) // 8 if len (getattr (self , dim )) > 8 else 1
422+ dim : len (getattr (self , dim )) // chunk_count_overrides .get (dim , 2 )
423+ if len (getattr (self , dim )) > 8 else 1
413424 for dim in self .dims
414- if dim not in ["init_time" , "step" ]
425+ if dim not in ["init_time" , "step" , "variable" ]
415426 }
416427
417428 return out_dict
418429
419430
420- def as_zeroed_dataarray (self , name : str ) -> xr .DataArray :
431+ def as_zeroed_dataarray (self , name : str , chunks : dict [ str , int ] ) -> xr .DataArray :
421432 """Express the coordinates as an xarray DataArray.
422433
423- Data is populated with zeros and a default chunking scheme is applied.
434+ The underlying dask array is a zeroed array with the shape of the dataset,
435+ that is chunked according to the given chunking scheme.
424436
425437 Args:
426438 name: The name of the DataArray.
439+ chunks: A mapping of dimension names to the size of the chunks
440+ along the dimensions.
427441
428442 See Also:
429- - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
443+ - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
430444 """
431445 # Create a dask array of zeros with the shape of the dataset
432446 # * The values of this are ignored, only the shape and chunks are used
433447 dummy_values = dask .array .zeros ( # type: ignore
434448 shape = list (self .shapemap .values ()),
435- chunks = tuple ([self . default_chunking () [k ] for k in self .shapemap ]),
449+ chunks = tuple ([chunks [k ] for k in self .shapemap ]),
436450 )
437451 attrs : dict [str , str ] = {
438452 "produced_by" : "" .join ((
0 commit comments