Skip to content

Commit a3dda3a

Browse files
committed
Merge branch 'develop'
2 parents 7cc1eea + e9ad943 commit a3dda3a

File tree

9 files changed

+145
-57
lines changed

9 files changed

+145
-57
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
- id: debug-statements # Check for debugger imports and py37+ breakpoint()
2121
- id: end-of-file-fixer # Ensure files end in a newline
2222
- id: trailing-whitespace # Trailing whitespace checker
23-
- id: no-commit-to-branch # Prevent committing to main / master
23+
# - id: no-commit-to-branch # Prevent committing to main / master
2424
- id: check-added-large-files # Check for large files added to git
2525
- id: check-merge-conflict # Check for files that contain merge conflict
2626

docs/building/handling-missing-values.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,16 @@
22
Handling missing values
33
#########################
44

5-
.. literalinclude:: ../../tests/create/nan.yaml
5+
When handling data for machine learning models, missing values (NaNs)
6+
can pose a challenge, as models require complete data to operate
7+
effectively and may crash otherwise. Ideally, we anticipate having
8+
complete data in all fields. However, there are scenarios where NaNs
9+
naturally occur, such as with variables only relevant on land or at sea
10+
(such as sea surface temperature (`sst`), for example). In such cases,
11+
the default behavior is to reject data with NaNs as invalid. To
12+
accommodate NaNs and accurately compute statistics based on them, you
13+
can include the `allow_nans` key in the configuration. Here's an example
14+
of how to implement it:
15+
16+
.. literalinclude:: yaml/nan.yaml
617
:language: yaml

docs/building/statistics.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ algorithm:
1717
- If the dataset covers 10 years or more, the last year is excluded.
1818
- Otherwise, 80% of the dataset is used.
1919

20-
You can override this behaviour by setting the `start` and `end`
20+
You can override this behaviour by setting the `start` or `end`
2121
parameters in the `statistics` config.
2222

2323
.. code:: yaml

docs/building/yaml/nan.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
statistics:
2+
allow_nans: [sst, ci]

src/anemoi/datasets/compute/perturbations.py

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,59 +7,95 @@
77
# nor does it submit to any jurisdiction.
88
#
99

10-
import warnings
10+
import logging
1111

1212
import numpy as np
1313
from climetlab.core.temporary import temp_file
1414
from climetlab.readers.grib.output import new_grib_output
1515

16-
from anemoi.datasets.create.check import check_data_values
1716
from anemoi.datasets.create.functions import assert_is_fieldset
1817

18+
LOG = logging.getLogger(__name__)
19+
20+
CLIP_VARIABLES = (
21+
"q",
22+
"cp",
23+
"lsp",
24+
"tp",
25+
"sf",
26+
"swl4",
27+
"swl3",
28+
"swl2",
29+
"swl1",
30+
)
31+
32+
SKIP = ("class", "stream", "type", "number", "expver", "_leg_number", "anoffset")
33+
34+
35+
def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
36+
assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid)
37+
assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area)
38+
assert f1.shape == f2.shape, (f1.shape, f2.shape)
39+
40+
# Not in *_as_mars
41+
assert f1.metadata("valid_datetime") == f2.metadata("valid_datetime"), (
42+
f1.metadata("valid_datetime"),
43+
f2.metadata("valid_datetime"),
44+
)
45+
46+
for k in set(center_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()):
47+
if k in SKIP:
48+
continue
49+
assert center_field_as_mars[k] == ensemble_field_as_mars[k], (
50+
k,
51+
center_field_as_mars[k],
52+
ensemble_field_as_mars[k],
53+
)
54+
1955

2056
def perturbations(
57+
*,
2158
members,
2259
center,
23-
positive_clipping_variables=[
24-
"q",
25-
"cp",
26-
"lsp",
27-
"tp",
28-
], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
60+
clip_variables=CLIP_VARIABLES,
61+
output=None,
2962
):
3063

3164
keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]
3265

33-
def check_compatible(f1, f2, ignore=["number"]):
34-
for k in keys + ["grid", "shape"]:
35-
if k in ignore:
36-
continue
37-
assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
66+
number_list = members.unique_values("number")["number"]
67+
n_numbers = len(number_list)
3868

39-
print(f"Retrieving ensemble data with {members}")
40-
print(f"Retrieving center data with {center}")
69+
assert None not in number_list
4170

71+
LOG.info("Ordering fields")
4272
members = members.order_by(*keys)
4373
center = center.order_by(*keys)
44-
45-
number_list = members.unique_values("number")["number"]
46-
n_numbers = len(number_list)
74+
LOG.info("Done")
4775

4876
if len(center) * n_numbers != len(members):
49-
print(len(center), n_numbers, len(members))
77+
LOG.error("%s %s %s", len(center), n_numbers, len(members))
5078
for f in members:
51-
print("Member: ", f)
79+
LOG.error("Member: %r", f)
5280
for f in center:
53-
print("Center: ", f)
81+
LOG.error("Center: %r", f)
5482
raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")
5583

56-
# prepare output tmp file so we can read it back
57-
tmp = temp_file()
58-
path = tmp.path
84+
if output is None:
85+
# prepare output tmp file so we can read it back
86+
tmp = temp_file()
87+
path = tmp.path
88+
else:
89+
tmp = None
90+
path = output
91+
5992
out = new_grib_output(path)
6093

94+
seen = set()
95+
6196
for i, center_field in enumerate(center):
6297
param = center_field.metadata("param")
98+
center_field_as_mars = center_field.as_mars()
6399

64100
# load the center field
65101
center_np = center_field.to_numpy()
@@ -69,9 +105,21 @@ def check_compatible(f1, f2, ignore=["number"]):
69105

70106
for j in range(n_numbers):
71107
ensemble_field = members[i * n_numbers + j]
72-
check_compatible(center_field, ensemble_field)
108+
ensemble_field_as_mars = ensemble_field.as_mars()
109+
check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars)
73110
members_np[j] = ensemble_field.to_numpy()
74111

112+
ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items()))
113+
assert ensemble_field_as_mars not in seen, ensemble_field_as_mars
114+
seen.add(ensemble_field_as_mars)
115+
116+
# cmin=np.amin(center_np)
117+
# emin=np.amin(members_np)
118+
119+
# if cmin < 0 and emin >= 0:
120+
# LOG.warning(f"Negative values in {param} cmin={cmin} emin={emin}")
121+
# LOG.warning(f"Center: {center_field_as_mars}")
122+
75123
mean_np = members_np.mean(axis=0)
76124

77125
for j in range(n_numbers):
@@ -84,18 +132,22 @@ def check_compatible(f1, f2, ignore=["number"]):
84132

85133
x = c - m + e
86134

87-
if param in positive_clipping_variables:
88-
warnings.warn(f"Clipping {param} to be positive")
135+
if param in clip_variables:
136+
# LOG.warning(f"Clipping {param} to be positive")
89137
x = np.maximum(x, 0)
90138

91139
assert x.shape == e.shape, (x.shape, e.shape)
92140

93-
check_data_values(x, name=param)
94141
out.write(x, template=template)
95142
template = None
96143

144+
assert len(seen) == len(members), (len(seen), len(members))
145+
97146
out.close()
98147

148+
if output is not None:
149+
return path
150+
99151
from climetlab import load_source
100152

101153
ds = load_source("file", path)

src/anemoi/datasets/create/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def patch(self, **kwargs):
9494

9595
apply_patch(self.path, **kwargs)
9696

97-
def init_additions(self, delta=[1, 3, 6, 12]):
97+
def init_additions(self, delta=[1, 3, 6, 12, 24]):
9898
from .loaders import StatisticsAddition
9999
from .loaders import TendenciesStatisticsAddition
100100
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -109,7 +109,7 @@ def init_additions(self, delta=[1, 3, 6, 12]):
109109
except TendenciesStatisticsDeltaNotMultipleOfFrequency:
110110
self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")
111111

112-
def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
112+
def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24]):
113113
from .loaders import StatisticsAddition
114114
from .loaders import TendenciesStatisticsAddition
115115
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -124,7 +124,7 @@ def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
124124
except TendenciesStatisticsDeltaNotMultipleOfFrequency:
125125
self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")
126126

127-
def finalise_additions(self, delta=[1, 3, 6, 12]):
127+
def finalise_additions(self, delta=[1, 3, 6, 12, 24]):
128128
from .loaders import StatisticsAddition
129129
from .loaders import TendenciesStatisticsAddition
130130
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency

src/anemoi/datasets/create/loaders.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -546,12 +546,17 @@ def write_stats_to_stdout(self, stats):
546546

547547

548548
class GenericAdditions(GenericDatasetHandler):
549-
def __init__(self, name="", **kwargs):
549+
def __init__(self, **kwargs):
550550
super().__init__(**kwargs)
551-
self.name = name
551+
self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)
552+
553+
@property
554+
def tmp_storage_path(self):
555+
raise NotImplementedError
552556

553-
storage_path = f"{self.path}.tmp_storage_{name}"
554-
self.tmp_storage = build_storage(directory=storage_path, create=True)
557+
@property
558+
def final_storage_path(self):
559+
raise NotImplementedError
555560

556561
def initialise(self):
557562
self.tmp_storage.delete()
@@ -589,7 +594,7 @@ def finalise(self):
589594
count=np.full(shape, -1, dtype=np.int64),
590595
has_nans=np.full(shape, False, dtype=np.bool_),
591596
)
592-
LOG.info(f"Aggregating {self.name} statistics on shape={shape}. Variables : {self.variables}")
597+
LOG.info(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
593598

594599
found = set()
595600
ifound = set()
@@ -659,17 +664,18 @@ def finalise(self):
659664

660665
def _write(self, summary):
661666
for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
662-
self._add_dataset(name=k, array=summary[k])
663-
self.registry.add_to_history("compute_statistics_end")
664-
LOG.info(f"Wrote {self.name} additions in {self.path}")
667+
name = self.final_storage_name(k)
668+
self._add_dataset(name=name, array=summary[k])
669+
self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
670+
LOG.info(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")
665671

666672
def check_statistics(self):
667673
pass
668674

669675

670676
class StatisticsAddition(GenericAdditions):
671677
def __init__(self, **kwargs):
672-
super().__init__("statistics_", **kwargs)
678+
super().__init__(**kwargs)
673679

674680
z = zarr.open(self.path, mode="r")
675681
start = z.attrs["statistics_start_date"]
@@ -682,6 +688,13 @@ def __init__(self, **kwargs):
682688
assert len(self.variables) == self.ds.shape[1], self.ds.shape
683689
self.total = len(self.dates)
684690

691+
@property
692+
def tmp_storage_path(self):
693+
return f"{self.path}.tmp_storage_statistics"
694+
695+
def final_storage_name(self, k):
696+
return k
697+
685698
def run(self, parts):
686699
chunk_filter = ChunkFilter(parts=parts, total=self.total)
687700
for i in range(0, self.total):
@@ -725,8 +738,6 @@ class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError):
725738

726739

727740
class TendenciesStatisticsAddition(GenericAdditions):
728-
DATASET_NAME_PATTERN = "statistics_tendencies_{delta}"
729-
730741
def __init__(self, path, delta=None, **kwargs):
731742
full_ds = open_dataset(path)
732743
self.variables = full_ds.variables
@@ -739,9 +750,10 @@ def __init__(self, path, delta=None, **kwargs):
739750
raise TendenciesStatisticsDeltaNotMultipleOfFrequency(
740751
f"Delta {delta} is not a multiple of frequency {frequency}"
741752
)
753+
self.delta = delta
742754
idelta = delta // frequency
743755

744-
super().__init__(path=path, name=self.DATASET_NAME_PATTERN.format(delta=f"{delta}h"), **kwargs)
756+
super().__init__(path=path, **kwargs)
745757

746758
z = zarr.open(self.path, mode="r")
747759
start = z.attrs["statistics_start_date"]
@@ -754,6 +766,21 @@ def __init__(self, path, delta=None, **kwargs):
754766
ds = open_dataset(self.path, start=start, end=end)
755767
self.ds = DeltaDataset(ds, idelta)
756768

769+
@property
770+
def tmp_storage_path(self):
771+
return f"{self.path}.tmp_storage_statistics_{self.delta}h"
772+
773+
def final_storage_name(self, k):
774+
return self.final_storage_name_from_delta(k, delta=self.delta)
775+
776+
@classmethod
777+
def final_storage_name_from_delta(_, k, delta):
778+
if isinstance(delta, int):
779+
delta = str(delta)
780+
if not delta.endswith("h"):
781+
delta = delta + "h"
782+
return f"statistics_tendencies_{delta}_{k}"
783+
757784
def run(self, parts):
758785
chunk_filter = ChunkFilter(parts=parts, total=self.total)
759786
for i in range(0, self.total):
@@ -768,9 +795,3 @@ def run(self, parts):
768795
self.tmp_storage.add([date, i, "missing"], key=date)
769796
self.tmp_storage.flush()
770797
LOG.info(f"Dataset {self.path} additions run.")
771-
772-
def _write(self, summary):
773-
for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
774-
self._add_dataset(name=f"{self.name}_{k}", array=summary[k])
775-
self.registry.add_to_history(f"compute_{self.name}_end")
776-
LOG.info(f"Wrote {self.name} additions in {self.path}")

src/anemoi/datasets/create/size.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import logging
1111
import os
1212

13+
from anemoi.utils.humanize import bytes
14+
1315
from anemoi.datasets.create.utils import progress_bar
1416

1517
LOG = logging.getLogger(__name__)

src/anemoi/datasets/data/stores.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,12 @@ def statistics_tendencies(self, delta=None):
225225
delta = f"{delta}h"
226226
from anemoi.datasets.create.loaders import TendenciesStatisticsAddition
227227

228-
prefix = TendenciesStatisticsAddition.DATASET_NAME_PATTERN.format(delta=delta) + "_"
228+
func = TendenciesStatisticsAddition.final_storage_name_from_delta
229229
return dict(
230-
mean=self.z[f"{prefix}mean"][:],
231-
stdev=self.z[f"{prefix}stdev"][:],
232-
maximum=self.z[f"{prefix}maximum"][:],
233-
minimum=self.z[f"{prefix}minimum"][:],
230+
mean=self.z[func("mean", delta)][:],
231+
stdev=self.z[func("stdev", delta)][:],
232+
maximum=self.z[func("maximum", delta)][:],
233+
minimum=self.z[func("minimum", delta)][:],
234234
)
235235

236236
@property

0 commit comments

Comments
 (0)