Skip to content

Commit 5e78201

Browse files
committed
update tests
1 parent a795999 commit 5e78201

File tree

2 files changed

+79
-45
lines changed

2 files changed

+79
-45
lines changed

src/anemoi/datasets/create/input.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -106,30 +106,32 @@ def _data_request(data):
106106
area = grid = None
107107

108108
for field in data:
109-
if not hasattr(field, "as_mars"):
110-
continue
111-
112-
if date is None:
113-
date = field.datetime()["valid_time"]
114-
115-
if field.datetime()["valid_time"] != date:
116-
continue
109+
try:
110+
if date is None:
111+
date = field.datetime()["valid_time"]
117112

118-
as_mars = field.metadata(namespace="mars")
119-
step = as_mars.get("step")
120-
levtype = as_mars.get("levtype", "sfc")
121-
param = as_mars["param"]
122-
levelist = as_mars.get("levelist", None)
123-
area = field.mars_area
124-
grid = field.mars_grid
113+
if field.datetime()["valid_time"] != date:
114+
continue
125115

126-
if levelist is None:
127-
params_levels[levtype].add(param)
128-
else:
129-
params_levels[levtype].add((param, levelist))
116+
as_mars = field.metadata(namespace="mars")
117+
if not as_mars:
118+
continue
119+
step = as_mars.get("step")
120+
levtype = as_mars.get("levtype", "sfc")
121+
param = as_mars["param"]
122+
levelist = as_mars.get("levelist", None)
123+
area = field.mars_area
124+
grid = field.mars_grid
125+
126+
if levelist is None:
127+
params_levels[levtype].add(param)
128+
else:
129+
params_levels[levtype].add((param, levelist))
130130

131-
if step:
132-
params_steps[levtype].add((param, step))
131+
if step:
132+
params_steps[levtype].add((param, step))
133+
except Exception:
134+
LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
133135

134136
def sort(old_dic):
135137
new_dic = {}

tests/create/test_create.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from anemoi.datasets.data.stores import open_zarr
2424

2525
TEST_DATA_ROOT = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/anemoi-datasets/create"
26+
TEST_DATA_S3_ROOT = "s3://ml-tests/test-data/anemoi-datasets/create"
2627

2728

2829
HERE = os.path.dirname(__file__)
@@ -94,29 +95,44 @@ def __call__(self, name, *args, **kwargs):
9495
_from_source = LoadSource()
9596

9697

97-
def compare_dot_zattrs(a, b):
98+
def compare_dot_zattrs(a, b, path, errors):
9899
if isinstance(a, dict):
99100
a_keys = list(a.keys())
100101
b_keys = list(b.keys())
101102
for k in set(a_keys) & set(b_keys):
102-
if k in ["timestamp", "uuid", "latest_write_timestamp", "yaml_config"]:
103-
assert type(a[k]) == type(b[k]), ( # noqa: E721
104-
type(a[k]),
105-
type(b[k]),
106-
a[k],
107-
b[k],
108-
)
109-
assert k in a_keys, (k, a_keys)
110-
assert k in b_keys, (k, b_keys)
111-
return compare_dot_zattrs(a[k], b[k])
103+
if k in [
104+
"timestamp",
105+
"uuid",
106+
"latest_write_timestamp",
107+
"yaml_config",
108+
"history",
109+
"provenance",
110+
"provenance_load",
111+
"description",
112+
"config_path",
113+
"dataset_status",
114+
]:
115+
if type(a[k]) != type(b[k]): # noqa : E721
116+
errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}")
117+
continue
118+
compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors)
119+
return
112120

113121
if isinstance(a, list):
114-
assert len(a) == len(b), (a, b)
115-
for v, w in zip(a, b):
116-
return compare_dot_zattrs(v, w)
117-
118-
assert type(a) == type(b), (type(a), type(b), a, b) # noqa: E721
119-
return a == b, (a, b)
122+
if len(a) != len(b):
123+
errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}")
124+
return
125+
for i, (v, w) in enumerate(zip(a, b)):
126+
compare_dot_zattrs(v, w, f"{path}.{i}", errors)
127+
return
128+
129+
if type(a) != type(b): # noqa : E721
130+
msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})"
131+
errors.append(msg)
132+
return
133+
if a != b:
134+
msg = f"❌ {path} actual != expected : {a} != {b}"
135+
errors.append(msg)
120136

121137

122138
def compare_datasets(a, b):
@@ -169,19 +185,29 @@ def compare_statistics(ds1, ds2):
169185
class Comparer:
170186
def __init__(self, name, output_path=None, reference_path=None):
171187
self.name = name
172-
self.reference = reference_path or os.path.join(TEST_DATA_ROOT, name + ".zarr")
173188
self.output = output_path or os.path.join(name + ".zarr")
174-
print(f"Comparing {self.reference} and {self.output}")
189+
self.reference_path = reference_path
190+
print(f"Comparing {self.output} and {self.reference_path}")
175191

176-
self.z_reference = open_zarr(self.reference)
177192
self.z_output = open_zarr(self.output)
193+
self.z_reference = open_zarr(self.reference_path)
178194

179-
self.ds_reference = open_dataset(self.reference)
195+
self.z_reference["data"]
180196
self.ds_output = open_dataset(self.output)
197+
self.ds_reference = open_dataset(self.reference_path)
181198

182199
def compare(self):
183-
compare_dot_zattrs(self.z_output.attrs, self.z_reference.attrs)
200+
errors = []
201+
compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors)
202+
if errors:
203+
print("Comparison failed")
204+
print("\n".join(errors))
205+
206+
if errors:
207+
raise AssertionError("Comparison failed")
208+
184209
compare_datasets(self.ds_output, self.ds_reference)
210+
185211
compare_statistics(self.ds_output, self.ds_reference)
186212

187213

@@ -199,8 +225,14 @@ def test_run(name):
199225
c.additions(delta=[1, 3, 6, 12])
200226
c.cleanup()
201227

202-
comparer = Comparer(name, output_path=output)
203-
comparer.compare()
228+
# reference_path = os.path.join(HERE, name + "-reference.zarr")
229+
s3_uri = TEST_DATA_S3_ROOT + "/" + name + ".zarr"
230+
# if not os.path.exists(reference_path):
231+
# from anemoi.utils.s3 import download as s3_download
232+
# s3_download(s3_uri + '/', reference_path, overwrite=True)
233+
234+
Comparer(name, output_path=output, reference_path=s3_uri).compare()
235+
# Comparer(name, output_path=output, reference_path=reference_path).compare()
204236

205237

206238
if __name__ == "__main__":

0 commit comments

Comments
 (0)