Skip to content

Commit 64378e2

Browse files
committed
perturbations in anemoi.datasets.compute.*
1 parent 7961f16 commit 64378e2

File tree

3 files changed

+113
-96
lines changed

3 files changed

+113
-96
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ classifiers = [
4040
]
4141

4242
dependencies = [
43-
"anemoi-utils[provenance]",
43+
"anemoi-utils[provenance]>=0.1.7",
4444
"zarr",
4545
"pyyaml",
4646
"numpy",
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# (C) Copyright 2024 ECMWF.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
# In applying this licence, ECMWF does not waive the privileges and immunities
6+
# granted to it by virtue of its status as an intergovernmental organisation
7+
# nor does it submit to any jurisdiction.
8+
#
9+
10+
import warnings
11+
12+
import numpy as np
13+
from climetlab.core.temporary import temp_file
14+
from climetlab.readers.grib.output import new_grib_output
15+
16+
from anemoi.datasets.create.check import check_data_values
17+
from anemoi.datasets.create.functions import assert_is_fieldset
18+
19+
20+
def perturbations(
21+
members,
22+
center,
23+
positive_clipping_variables=[
24+
"q",
25+
"cp",
26+
"lsp",
27+
"tp",
28+
], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
29+
):
30+
31+
keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]
32+
33+
def check_compatible(f1, f2, ignore=["number"]):
34+
for k in keys + ["grid", "shape"]:
35+
if k in ignore:
36+
continue
37+
assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
38+
39+
print(f"Retrieving ensemble data with {members}")
40+
print(f"Retrieving center data with {center}")
41+
42+
members = members.order_by(*keys)
43+
center = center.order_by(*keys)
44+
45+
number_list = members.unique_values("number")["number"]
46+
n_numbers = len(number_list)
47+
48+
if len(center) * n_numbers != len(members):
49+
print(len(center), n_numbers, len(members))
50+
for f in members:
51+
print("Member: ", f)
52+
for f in center:
53+
print("Center: ", f)
54+
raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")
55+
56+
# prepare output tmp file so we can read it back
57+
tmp = temp_file()
58+
path = tmp.path
59+
out = new_grib_output(path)
60+
61+
for i, center_field in enumerate(center):
62+
param = center_field.metadata("param")
63+
64+
# load the center field
65+
center_np = center_field.to_numpy()
66+
67+
# load the ensemble fields and compute the mean
68+
members_np = np.zeros((n_numbers, *center_np.shape))
69+
70+
for j in range(n_numbers):
71+
ensemble_field = members[i * n_numbers + j]
72+
check_compatible(center_field, ensemble_field)
73+
members_np[j] = ensemble_field.to_numpy()
74+
75+
mean_np = members_np.mean(axis=0)
76+
77+
for j in range(n_numbers):
78+
template = members[i * n_numbers + j]
79+
e = members_np[j]
80+
m = mean_np
81+
c = center_np
82+
83+
assert e.shape == c.shape == m.shape, (e.shape, c.shape, m.shape)
84+
85+
x = c - m + e
86+
87+
if param in positive_clipping_variables:
88+
warnings.warn(f"Clipping {param} to be positive")
89+
x = np.maximum(x, 0)
90+
91+
assert x.shape == e.shape, (x.shape, e.shape)
92+
93+
check_data_values(x, name=param)
94+
out.write(x, template=template)
95+
template = None
96+
97+
out.close()
98+
99+
from climetlab import load_source
100+
101+
ds = load_source("file", path)
102+
assert_is_fieldset(ds)
103+
# save a reference to the tmp file so it is deleted
104+
# only when the dataset is not used anymore
105+
ds._tmp = tmp
106+
107+
assert len(ds) == len(members), (len(ds), len(members))
108+
109+
return ds

src/anemoi/datasets/create/functions/sources/perturbations.py

Lines changed: 3 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,11 @@
66
# granted to it by virtue of its status as an intergovernmental organisation
77
# nor does it submit to any jurisdiction.
88
#
9-
import warnings
109
from copy import deepcopy
1110

12-
import numpy as np
13-
from climetlab.core.temporary import temp_file
14-
from climetlab.readers.grib.output import new_grib_output
11+
from anemoi.datasets.compute.perturbations import perturbations as compute_perturbations
1512

16-
from anemoi.datasets.create.check import check_data_values
17-
from anemoi.datasets.create.functions import assert_is_fieldset
18-
from anemoi.datasets.create.functions.actions.mars import mars
13+
from .mars import mars
1914

2015

2116
def to_list(x):
@@ -58,94 +53,7 @@ def load_if_needed(context, dates, dict_or_dataset):
5853
def perturbations(context, dates, members, center, remapping={}, patches={}):
5954
members = load_if_needed(context, dates, members)
6055
center = load_if_needed(context, dates, center)
61-
# return perturbations(member, centers....)
62-
63-
keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]
64-
65-
def check_compatible(f1, f2, ignore=["number"]):
66-
for k in keys + ["grid", "shape"]:
67-
if k in ignore:
68-
continue
69-
assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
70-
71-
print(f"Retrieving ensemble data with {members}")
72-
print(f"Retrieving center data with {center}")
73-
74-
members = members.order_by(*keys)
75-
center = center.order_by(*keys)
76-
77-
number_list = members.unique_values("number")["number"]
78-
n_numbers = len(number_list)
79-
80-
if len(center) * n_numbers != len(members):
81-
print(len(center), n_numbers, len(members))
82-
for f in members:
83-
print("Member: ", f)
84-
for f in center:
85-
print("Center: ", f)
86-
raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")
87-
88-
# prepare output tmp file so we can read it back
89-
tmp = temp_file()
90-
path = tmp.path
91-
out = new_grib_output(path)
92-
93-
for i, center_field in enumerate(center):
94-
param = center_field.metadata("param")
95-
96-
# load the center field
97-
center_np = center_field.to_numpy()
98-
99-
# load the ensemble fields and compute the mean
100-
members_np = np.zeros((n_numbers, *center_np.shape))
101-
102-
for j in range(n_numbers):
103-
ensemble_field = members[i * n_numbers + j]
104-
check_compatible(center_field, ensemble_field)
105-
members_np[j] = ensemble_field.to_numpy()
106-
107-
mean_np = members_np.mean(axis=0)
108-
109-
for j in range(n_numbers):
110-
template = members[i * n_numbers + j]
111-
e = members_np[j]
112-
m = mean_np
113-
c = center_np
114-
115-
assert e.shape == c.shape == m.shape, (e.shape, c.shape, m.shape)
116-
117-
FORCED_POSITIVE = [
118-
"q",
119-
"cp",
120-
"lsp",
121-
"tp",
122-
] # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
123-
124-
x = c - m + e
125-
126-
if param in FORCED_POSITIVE:
127-
warnings.warn(f"Clipping {param} to be positive")
128-
x = np.maximum(x, 0)
129-
130-
assert x.shape == e.shape, (x.shape, e.shape)
131-
132-
check_data_values(x, name=param)
133-
out.write(x, template=template)
134-
template = None
135-
136-
out.close()
137-
138-
from climetlab import load_source
139-
140-
ds = load_source("file", path)
141-
assert_is_fieldset(ds)
142-
# save a reference to the tmp file so it is deleted
143-
# only when the dataset is not used anymore
144-
ds._tmp = tmp
145-
146-
assert len(ds) == len(members), (len(ds), len(members))
147-
148-
return ds
56+
return compute_perturbations(members, center)
14957

15058

15159
execute = perturbations

0 commit comments

Comments
 (0)