|
6 | 6 | # granted to it by virtue of its status as an intergovernmental organisation
|
7 | 7 | # nor does it submit to any jurisdiction.
|
8 | 8 | #
|
9 |
| -import warnings |
10 | 9 | from copy import deepcopy
|
11 | 10 |
|
12 |
| -import numpy as np |
13 |
| -from climetlab.core.temporary import temp_file |
14 |
| -from climetlab.readers.grib.output import new_grib_output |
| 11 | +from anemoi.datasets.compute.perturbations import perturbations as compute_perturbations |
15 | 12 |
|
16 |
| -from anemoi.datasets.create.check import check_data_values |
17 |
| -from anemoi.datasets.create.functions import assert_is_fieldset |
18 |
| -from anemoi.datasets.create.functions.actions.mars import mars |
| 13 | +from .mars import mars |
19 | 14 |
|
20 | 15 |
|
21 | 16 | def to_list(x):
|
@@ -58,94 +53,7 @@ def load_if_needed(context, dates, dict_or_dataset):
|
58 | 53 | def perturbations(context, dates, members, center, remapping={}, patches={}):
|
59 | 54 | members = load_if_needed(context, dates, members)
|
60 | 55 | center = load_if_needed(context, dates, center)
|
61 |
| - # return perturbations(member, centers....) |
62 |
| - |
63 |
| - keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"] |
64 |
| - |
65 |
| - def check_compatible(f1, f2, ignore=["number"]): |
66 |
| - for k in keys + ["grid", "shape"]: |
67 |
| - if k in ignore: |
68 |
| - continue |
69 |
| - assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k)) |
70 |
| - |
71 |
| - print(f"Retrieving ensemble data with {members}") |
72 |
| - print(f"Retrieving center data with {center}") |
73 |
| - |
74 |
| - members = members.order_by(*keys) |
75 |
| - center = center.order_by(*keys) |
76 |
| - |
77 |
| - number_list = members.unique_values("number")["number"] |
78 |
| - n_numbers = len(number_list) |
79 |
| - |
80 |
| - if len(center) * n_numbers != len(members): |
81 |
| - print(len(center), n_numbers, len(members)) |
82 |
| - for f in members: |
83 |
| - print("Member: ", f) |
84 |
| - for f in center: |
85 |
| - print("Center: ", f) |
86 |
| - raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}") |
87 |
| - |
88 |
| - # prepare output tmp file so we can read it back |
89 |
| - tmp = temp_file() |
90 |
| - path = tmp.path |
91 |
| - out = new_grib_output(path) |
92 |
| - |
93 |
| - for i, center_field in enumerate(center): |
94 |
| - param = center_field.metadata("param") |
95 |
| - |
96 |
| - # load the center field |
97 |
| - center_np = center_field.to_numpy() |
98 |
| - |
99 |
| - # load the ensemble fields and compute the mean |
100 |
| - members_np = np.zeros((n_numbers, *center_np.shape)) |
101 |
| - |
102 |
| - for j in range(n_numbers): |
103 |
| - ensemble_field = members[i * n_numbers + j] |
104 |
| - check_compatible(center_field, ensemble_field) |
105 |
| - members_np[j] = ensemble_field.to_numpy() |
106 |
| - |
107 |
| - mean_np = members_np.mean(axis=0) |
108 |
| - |
109 |
| - for j in range(n_numbers): |
110 |
| - template = members[i * n_numbers + j] |
111 |
| - e = members_np[j] |
112 |
| - m = mean_np |
113 |
| - c = center_np |
114 |
| - |
115 |
| - assert e.shape == c.shape == m.shape, (e.shape, c.shape, m.shape) |
116 |
| - |
117 |
| - FORCED_POSITIVE = [ |
118 |
| - "q", |
119 |
| - "cp", |
120 |
| - "lsp", |
121 |
| - "tp", |
122 |
| - ] # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ? |
123 |
| - |
124 |
| - x = c - m + e |
125 |
| - |
126 |
| - if param in FORCED_POSITIVE: |
127 |
| - warnings.warn(f"Clipping {param} to be positive") |
128 |
| - x = np.maximum(x, 0) |
129 |
| - |
130 |
| - assert x.shape == e.shape, (x.shape, e.shape) |
131 |
| - |
132 |
| - check_data_values(x, name=param) |
133 |
| - out.write(x, template=template) |
134 |
| - template = None |
135 |
| - |
136 |
| - out.close() |
137 |
| - |
138 |
| - from climetlab import load_source |
139 |
| - |
140 |
| - ds = load_source("file", path) |
141 |
| - assert_is_fieldset(ds) |
142 |
| - # save a reference to the tmp file so it is deleted |
143 |
| - # only when the dataset is not used anymore |
144 |
| - ds._tmp = tmp |
145 |
| - |
146 |
| - assert len(ds) == len(members), (len(ds), len(members)) |
147 |
| - |
148 |
| - return ds |
| 56 | + return compute_perturbations(members, center) |
149 | 57 |
|
150 | 58 |
|
151 | 59 | execute = perturbations
|
0 commit comments